You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by je...@apache.org on 2022/01/21 00:07:05 UTC
[airflow] 13/23: Un-ignore DeprecationWarning (#20322)
This is an automated email from the ASF dual-hosted git repository.
jedcunningham pushed a commit to branch v2-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit a25d7cef7f10be25b2446abe641c0b5822e9d9dc
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Tue Dec 21 18:00:46 2021 +0800
Un-ignore DeprecationWarning (#20322)
(cherry picked from commit 9876e19273cd56dc53d3a4e287db43acbfa65c4b)
---
airflow/models/taskinstance.py | 41 +++++------
airflow/operators/datetime.py | 2 +-
airflow/operators/python.py | 26 ++++---
airflow/operators/weekday.py | 2 +-
airflow/providers/http/operators/http.py | 10 +--
airflow/providers/http/sensors/http.py | 7 +-
airflow/sensors/external_task.py | 24 +++----
airflow/sensors/weekday.py | 2 +-
airflow/utils/context.py | 33 +++++++++
airflow/utils/context.pyi | 6 +-
airflow/utils/helpers.py | 2 +-
.../log/task_handler_with_custom_formatter.py | 4 +-
airflow/utils/operator_helpers.py | 84 +++++++++++++++++-----
scripts/ci/kubernetes/ci_run_kubernetes_tests.sh | 7 +-
scripts/in_container/entrypoint_ci.sh | 2 -
tests/cli/commands/test_task_command.py | 2 +
tests/core/test_core.py | 21 +++---
tests/operators/test_email.py | 2 +-
tests/operators/test_python.py | 9 ++-
tests/operators/test_trigger_dagrun.py | 2 +-
tests/providers/http/sensors/test_http.py | 4 +-
tests/sensors/test_external_task_sensor.py | 8 +--
tests/utils/test_log_handlers.py | 6 +-
23 files changed, 195 insertions(+), 111 deletions(-)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index f37cada..716167c 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -86,7 +86,7 @@ from airflow.typing_compat import Literal
from airflow.utils import timezone
from airflow.utils.context import ConnectionAccessor, Context, VariableAccessor
from airflow.utils.email import send_email
-from airflow.utils.helpers import is_container
+from airflow.utils.helpers import is_container, render_template_to_string
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.net import get_hostname
from airflow.utils.operator_helpers import context_to_airflow_vars
@@ -2016,7 +2016,7 @@ class TaskInstance(Base, LoggingMixin):
sanitized_pod = ApiClient().sanitize_for_serialization(pod)
return sanitized_pod
- def get_email_subject_content(self, exception):
+ def get_email_subject_content(self, exception: BaseException) -> Tuple[str, str, str]:
"""Get the email subject content for exceptions."""
# For a ti from DB (without ti.task), return the default value
# Reuse it for smart sensor to send default email alert
@@ -2043,18 +2043,18 @@ class TaskInstance(Base, LoggingMixin):
'Mark success: <a href="{{ti.mark_success_url}}">Link</a><br>'
)
+ # This function is called after changing the state from State.RUNNING,
+ # so we need to subtract 1 from self.try_number here.
+ current_try_number = self.try_number - 1
+ additional_context = {
+ "exception": exception,
+ "exception_html": exception_html,
+ "try_number": current_try_number,
+ "max_tries": self.max_tries,
+ }
+
if use_default:
- jinja_context = {'ti': self}
- # This function is called after changing the state
- # from State.RUNNING so need to subtract 1 from self.try_number.
- jinja_context.update(
- dict(
- exception=exception,
- exception_html=exception_html,
- try_number=self.try_number - 1,
- max_tries=self.max_tries,
- )
- )
+ jinja_context = {"ti": self, **additional_context}
jinja_env = jinja2.Environment(
loader=jinja2.FileSystemLoader(os.path.dirname(__file__)), autoescape=True
)
@@ -2064,24 +2064,15 @@ class TaskInstance(Base, LoggingMixin):
else:
jinja_context = self.get_template_context()
-
- jinja_context.update(
- dict(
- exception=exception,
- exception_html=exception_html,
- try_number=self.try_number - 1,
- max_tries=self.max_tries,
- )
- )
-
+ jinja_context.update(additional_context)
jinja_env = self.task.get_template_env()
- def render(key, content):
+ def render(key: str, content: str) -> str:
if conf.has_option('email', key):
path = conf.get('email', key)
with open(path) as f:
content = f.read()
- return jinja_env.from_string(content).render(**jinja_context)
+ return render_template_to_string(jinja_env.from_string(content), jinja_context)
subject = render('subject_template', default_subject)
html_content = render('html_content_template', default_html_content)
diff --git a/airflow/operators/datetime.py b/airflow/operators/datetime.py
index 6b1acf7..15d4300 100644
--- a/airflow/operators/datetime.py
+++ b/airflow/operators/datetime.py
@@ -72,7 +72,7 @@ class BranchDateTimeOperator(BaseBranchOperator):
def choose_branch(self, context: Dict) -> Union[str, Iterable[str]]:
if self.use_task_execution_date is True:
- now = timezone.make_naive(context["execution_date"], self.dag.timezone)
+ now = timezone.make_naive(context["logical_date"], self.dag.timezone)
else:
now = timezone.make_naive(timezone.utcnow(), self.dag.timezone)
diff --git a/airflow/operators/python.py b/airflow/operators/python.py
index 5b552b8..8e51536 100644
--- a/airflow/operators/python.py
+++ b/airflow/operators/python.py
@@ -24,7 +24,7 @@ import types
import warnings
from tempfile import TemporaryDirectory
from textwrap import dedent
-from typing import Callable, Dict, Iterable, List, Optional, Union
+from typing import Any, Callable, Collection, Dict, Iterable, List, Mapping, Optional, Union
import dill
@@ -33,7 +33,7 @@ from airflow.models import BaseOperator
from airflow.models.skipmixin import SkipMixin
from airflow.models.taskinstance import _CURRENT_CONTEXT
from airflow.utils.context import Context
-from airflow.utils.operator_helpers import determine_kwargs
+from airflow.utils.operator_helpers import KeywordParameters
from airflow.utils.process_utils import execute_in_subprocess
from airflow.utils.python_virtualenv import prepare_virtualenv, write_python_script
@@ -142,8 +142,8 @@ class PythonOperator(BaseOperator):
self,
*,
python_callable: Callable,
- op_args: Optional[List] = None,
- op_kwargs: Optional[Dict] = None,
+ op_args: Optional[Collection[Any]] = None,
+ op_kwargs: Optional[Mapping[str, Any]] = None,
templates_dict: Optional[Dict] = None,
templates_exts: Optional[List[str]] = None,
**kwargs,
@@ -159,7 +159,7 @@ class PythonOperator(BaseOperator):
if not callable(python_callable):
raise AirflowException('`python_callable` param must be callable')
self.python_callable = python_callable
- self.op_args = op_args or []
+ self.op_args = op_args or ()
self.op_kwargs = op_kwargs or {}
self.templates_dict = templates_dict
if templates_exts:
@@ -169,12 +169,15 @@ class PythonOperator(BaseOperator):
context.update(self.op_kwargs)
context['templates_dict'] = self.templates_dict
- self.op_kwargs = determine_kwargs(self.python_callable, self.op_args, context)
+ self.op_kwargs = self.determine_kwargs(context)
return_value = self.execute_callable()
self.log.info("Done. Returned value was: %s", return_value)
return return_value
+ def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
+ return KeywordParameters.determine(self.python_callable, self.op_args, context).unpacking()
+
def execute_callable(self):
"""
Calls the python callable with the given arguments.
@@ -241,11 +244,11 @@ class ShortCircuitOperator(PythonOperator, SkipMixin):
self.log.info('Skipping downstream tasks...')
- downstream_tasks = context['task'].get_flat_relatives(upstream=False)
+ downstream_tasks = context["task"].get_flat_relatives(upstream=False)
self.log.debug("Downstream task_ids %s", downstream_tasks)
if downstream_tasks:
- self.skip(context['dag_run'], context['ti'].execution_date, downstream_tasks)
+ self.skip(context["dag_run"], context["logical_date"], downstream_tasks)
self.log.info("Done.")
@@ -345,8 +348,8 @@ class PythonVirtualenvOperator(PythonOperator):
python_version: Optional[Union[str, int, float]] = None,
use_dill: bool = False,
system_site_packages: bool = True,
- op_args: Optional[List] = None,
- op_kwargs: Optional[Dict] = None,
+ op_args: Optional[Collection[Any]] = None,
+ op_kwargs: Optional[Mapping[str, Any]] = None,
string_args: Optional[Iterable[str]] = None,
templates_dict: Optional[Dict] = None,
templates_exts: Optional[List[str]] = None,
@@ -392,6 +395,9 @@ class PythonVirtualenvOperator(PythonOperator):
serializable_context = context.copy_only(serializable_keys)
return super().execute(context=serializable_context)
+ def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
+ return KeywordParameters.determine(self.python_callable, self.op_args, context).serializing()
+
def execute_callable(self):
with TemporaryDirectory(prefix='venv') as tmp_dir:
if self.templates_dict:
diff --git a/airflow/operators/weekday.py b/airflow/operators/weekday.py
index e1167a5..2e4e656 100644
--- a/airflow/operators/weekday.py
+++ b/airflow/operators/weekday.py
@@ -67,7 +67,7 @@ class BranchDayOfWeekOperator(BaseBranchOperator):
def choose_branch(self, context: Dict) -> Union[str, Iterable[str]]:
if self.use_task_execution_day:
- now = context["execution_date"]
+ now = context["logical_date"]
else:
now = timezone.make_naive(timezone.utcnow(), self.dag.timezone)
diff --git a/airflow/providers/http/operators/http.py b/airflow/providers/http/operators/http.py
index b629518..d36ceb2 100644
--- a/airflow/providers/http/operators/http.py
+++ b/airflow/providers/http/operators/http.py
@@ -104,7 +104,7 @@ class SimpleHttpOperator(BaseOperator):
raise AirflowException("'xcom_push' was deprecated, use 'BaseOperator.do_xcom_push' instead")
def execute(self, context: Dict[str, Any]) -> Any:
- from airflow.utils.operator_helpers import make_kwargs_callable
+ from airflow.utils.operator_helpers import determine_kwargs
http = HttpHook(self.method, http_conn_id=self.http_conn_id, auth_type=self.auth_type)
@@ -114,10 +114,10 @@ class SimpleHttpOperator(BaseOperator):
if self.log_response:
self.log.info(response.text)
if self.response_check:
- kwargs_callable = make_kwargs_callable(self.response_check)
- if not kwargs_callable(response, **context):
+ kwargs = determine_kwargs(self.response_check, [response], context)
+ if not self.response_check(response, **kwargs):
raise AirflowException("Response check returned False.")
if self.response_filter:
- kwargs_callable = make_kwargs_callable(self.response_filter)
- return kwargs_callable(response, **context)
+ kwargs = determine_kwargs(self.response_filter, [response], context)
+ return self.response_filter(response, **kwargs)
return response.text
diff --git a/airflow/providers/http/sensors/http.py b/airflow/providers/http/sensors/http.py
index 6ef55ea..e052c01 100644
--- a/airflow/providers/http/sensors/http.py
+++ b/airflow/providers/http/sensors/http.py
@@ -96,7 +96,7 @@ class HttpSensor(BaseSensorOperator):
self.hook = HttpHook(method=method, http_conn_id=http_conn_id)
def poke(self, context: Dict[Any, Any]) -> bool:
- from airflow.utils.operator_helpers import make_kwargs_callable
+ from airflow.utils.operator_helpers import determine_kwargs
self.log.info('Poking: %s', self.endpoint)
try:
@@ -107,9 +107,8 @@ class HttpSensor(BaseSensorOperator):
extra_options=self.extra_options,
)
if self.response_check:
- kwargs_callable = make_kwargs_callable(self.response_check)
- return kwargs_callable(response, **context)
-
+ kwargs = determine_kwargs(self.response_check, [response], context)
+ return self.response_check(response, **kwargs)
except AirflowException as exc:
if str(exc).startswith("404"):
return False
diff --git a/airflow/sensors/external_task.py b/airflow/sensors/external_task.py
index c451001..32336d3 100644
--- a/airflow/sensors/external_task.py
+++ b/airflow/sensors/external_task.py
@@ -47,7 +47,7 @@ class ExternalTaskSensorLink(BaseOperatorLink):
class ExternalTaskSensor(BaseSensorOperator):
"""
Waits for a different DAG or a task in a different DAG to complete for a
- specific execution_date
+ specific logical date.
:param external_dag_id: The dag_id that contains the task you want to
wait for
@@ -65,14 +65,14 @@ class ExternalTaskSensor(BaseSensorOperator):
:param failed_states: Iterable of failed or dis-allowed states, default is ``None``
:type failed_states: Iterable
:param execution_delta: time difference with the previous execution to
- look at, the default is the same execution_date as the current task or DAG.
+ look at, the default is the same logical date as the current task or DAG.
For yesterday, use [positive!] datetime.timedelta(days=1). Either
execution_delta or execution_date_fn can be passed to
ExternalTaskSensor, but not both.
:type execution_delta: Optional[datetime.timedelta]
- :param execution_date_fn: function that receives the current execution date as the first
+ :param execution_date_fn: function that receives the current execution's logical date as the first
positional argument and optionally any number of keyword arguments available in the
- context dictionary, and returns the desired execution dates to query.
+ context dictionary, and returns the desired logical dates to query.
Either execution_delta or execution_date_fn can be passed to ExternalTaskSensor,
but not both.
:type execution_date_fn: Optional[Callable]
@@ -157,11 +157,11 @@ class ExternalTaskSensor(BaseSensorOperator):
@provide_session
def poke(self, context, session=None):
if self.execution_delta:
- dttm = context['execution_date'] - self.execution_delta
+ dttm = context['logical_date'] - self.execution_delta
elif self.execution_date_fn:
dttm = self._handle_execution_date_fn(context=context)
else:
- dttm = context['execution_date']
+ dttm = context['logical_date']
dttm_filter = dttm if isinstance(dttm, list) else [dttm]
serialized_dttm_filter = ','.join(dt.isoformat() for dt in dttm_filter)
@@ -260,14 +260,14 @@ class ExternalTaskSensor(BaseSensorOperator):
"""
from airflow.utils.operator_helpers import make_kwargs_callable
- # Remove "execution_date" because it is already a mandatory positional argument
- execution_date = context["execution_date"]
- kwargs = {k: v for k, v in context.items() if k != "execution_date"}
+ # Remove "logical_date" because it is already a mandatory positional argument
+ logical_date = context["logical_date"]
+ kwargs = {k: v for k, v in context.items() if k not in {"execution_date", "logical_date"}}
# Add "context" in the kwargs for backward compatibility (because context used to be
# an acceptable argument of execution_date_fn)
kwargs["context"] = context
kwargs_callable = make_kwargs_callable(self.execution_date_fn)
- return kwargs_callable(execution_date, **kwargs)
+ return kwargs_callable(logical_date, **kwargs)
class ExternalTaskMarker(DummyOperator):
@@ -281,7 +281,7 @@ class ExternalTaskMarker(DummyOperator):
:type external_dag_id: str
:param external_task_id: The task_id of the dependent task that needs to be cleared.
:type external_task_id: str
- :param execution_date: The execution_date of the dependent task that needs to be cleared.
+ :param execution_date: The logical date of the dependent task execution that needs to be cleared.
:type execution_date: str or datetime.datetime
:param recursion_depth: The maximum level of transitive dependencies allowed. Default is 10.
This is mostly used for preventing cyclic dependencies. It is fine to increase
@@ -300,7 +300,7 @@ class ExternalTaskMarker(DummyOperator):
*,
external_dag_id: str,
external_task_id: str,
- execution_date: Optional[Union[str, datetime.datetime]] = "{{ execution_date.isoformat() }}",
+ execution_date: Optional[Union[str, datetime.datetime]] = "{{ logical_date.isoformat() }}",
recursion_depth: int = 10,
**kwargs,
):
diff --git a/airflow/sensors/weekday.py b/airflow/sensors/weekday.py
index 03e3221..741e166 100644
--- a/airflow/sensors/weekday.py
+++ b/airflow/sensors/weekday.py
@@ -84,6 +84,6 @@ class DayOfWeekSensor(BaseSensorOperator):
WeekDay(timezone.utcnow().isoweekday()).name,
)
if self.use_task_execution_day:
- return context['execution_date'].isoweekday() in self._week_day_num
+ return context['logical_date'].isoweekday() in self._week_day_num
else:
return timezone.utcnow().isoweekday() in self._week_day_num
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index 61f9319..d8eee04 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -20,6 +20,7 @@
import contextlib
import copy
+import functools
import warnings
from typing import (
AbstractSet,
@@ -28,12 +29,15 @@ from typing import (
Dict,
Iterator,
List,
+ Mapping,
MutableMapping,
Optional,
Tuple,
ValuesView,
)
+import lazy_object_proxy
+
_NOT_SET: Any = object()
@@ -194,3 +198,32 @@ class Context(MutableMapping[str, Any]):
new = type(self)({k: v for k, v in self._context.items() if k in keys})
new._deprecation_replacements = self._deprecation_replacements.copy()
return new
+
+
+def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]:
+ """Create a mapping that wraps deprecated entries in a lazy object proxy.
+
+ This further delays the deprecation warning until the entry is actually
+ used, instead of when it's accessed in the context. The result is useful for
+ passing into a callable with ``**kwargs``, which would unpack the mapping
+ too eagerly otherwise.
+
+ This is implemented as a free function because the ``Context`` type is
+ "faked" as a ``TypedDict`` in ``context.pyi``, which cannot have custom
+ functions.
+
+ :meta private:
+ """
+
+ def _deprecated_proxy_factory(k: str, v: Any) -> Any:
+ replacements = source._deprecation_replacements[k]
+ warnings.warn(_create_deprecation_warning(k, replacements))
+ return v
+
+ def _create_value(k: str, v: Any) -> Any:
+ if k not in source._deprecation_replacements:
+ return v
+ factory = functools.partial(_deprecated_proxy_factory, k, v)
+ return lazy_object_proxy.Proxy(factory)
+
+ return {k: _create_value(k, v) for k, v in source._context.items()}
diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi
index 0921d79..44b152c 100644
--- a/airflow/utils/context.pyi
+++ b/airflow/utils/context.pyi
@@ -25,7 +25,7 @@
# undefined attribute errors from Mypy. Hopefully there will be a mechanism to
# declare "these are defined, but don't error if others are accessed" someday.
-from typing import Any, Optional
+from typing import Any, Mapping, Optional
from pendulum import DateTime
@@ -80,3 +80,7 @@ class Context(TypedDict, total=False):
var: _VariableAccessors
yesterday_ds: str
yesterday_ds_nodash: str
+
+class AirflowContextDeprecationWarning(DeprecationWarning): ...
+
+def lazy_mapping_from_context(source: Context) -> Mapping[str, Any]: ...
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index c5f9f27..2215c4c 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -167,7 +167,7 @@ def render_log_filename(ti: "TaskInstance", try_number, filename_template) -> st
if filename_jinja_template:
jinja_context = ti.get_template_context()
jinja_context['try_number'] = try_number
- return filename_jinja_template.render(**jinja_context)
+ return render_template_to_string(filename_jinja_template, jinja_context)
return filename_template.format(
dag_id=ti.dag_id,
diff --git a/airflow/utils/log/task_handler_with_custom_formatter.py b/airflow/utils/log/task_handler_with_custom_formatter.py
index 5034d00..b7b431b 100644
--- a/airflow/utils/log/task_handler_with_custom_formatter.py
+++ b/airflow/utils/log/task_handler_with_custom_formatter.py
@@ -20,7 +20,7 @@ import logging
from logging import StreamHandler
from airflow.configuration import conf
-from airflow.utils.helpers import parse_template_string
+from airflow.utils.helpers import parse_template_string, render_template_to_string
class TaskHandlerWithCustomFormatter(StreamHandler):
@@ -52,6 +52,6 @@ class TaskHandlerWithCustomFormatter(StreamHandler):
def _render_prefix(self, ti):
if self.prefix_jinja_template:
jinja_context = ti.get_template_context()
- return self.prefix_jinja_template.render(**jinja_context)
+ return render_template_to_string(self.prefix_jinja_template, jinja_context)
logging.warning("'task_log_prefix_template' is in invalid format, ignoring the variable value")
return ""
diff --git a/airflow/utils/operator_helpers.py b/airflow/utils/operator_helpers.py
index 8c5125b..05c050c 100644
--- a/airflow/utils/operator_helpers.py
+++ b/airflow/utils/operator_helpers.py
@@ -17,7 +17,9 @@
# under the License.
#
from datetime import datetime
-from typing import Callable, Dict, List, Mapping, Tuple, Union
+from typing import Any, Callable, Collection, Mapping
+
+from airflow.utils.context import Context, lazy_mapping_from_context
AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
'AIRFLOW_CONTEXT_DAG_ID': {'default': 'airflow.ctx.dag_id', 'env_var_format': 'AIRFLOW_CTX_DAG_ID'},
@@ -88,7 +90,67 @@ def context_to_airflow_vars(context, in_env_var_format=False):
return params
-def determine_kwargs(func: Callable, args: Union[Tuple, List], kwargs: Mapping) -> Dict:
+class KeywordParameters:
+ """Wrapper representing ``**kwargs`` to a callable.
+
+ The actual ``kwargs`` can be obtained by calling either ``unpacking()`` or
+ ``serializing()``. They behave almost the same and differ only if
+ the contained ``kwargs`` is an Airflow Context object, and the called
+ function uses ``**kwargs`` in the argument list.
+
+ In this particular case, ``unpacking()`` uses ``lazy-object-proxy`` to
+ prevent the Context from emitting deprecation warnings too eagerly when it's
+ unpacked by ``**``. ``serializing()`` does not do this, and will allow the
+ warnings to be emitted eagerly, which is useful when you want to dump the
+ content and use it somewhere else without needing ``lazy-object-proxy``.
+ """
+
+ def __init__(self, kwargs: Mapping[str, Any], *, wildcard: bool) -> None:
+ self._kwargs = kwargs
+ self._wildcard = wildcard
+
+ @classmethod
+ def determine(
+ cls,
+ func: Callable[..., Any],
+ args: Collection[Any],
+ kwargs: Mapping[str, Any],
+ ) -> "KeywordParameters":
+ import inspect
+ import itertools
+
+ signature = inspect.signature(func)
+ has_wildcard_kwargs = any(p.kind == p.VAR_KEYWORD for p in signature.parameters.values())
+
+ for name in itertools.islice(signature.parameters.keys(), len(args)):
+ # Check if args conflict with names in kwargs.
+ if name in kwargs:
+ raise ValueError(f"The key {name!r} in args is a part of kwargs and therefore reserved.")
+
+ if has_wildcard_kwargs:
+ # If the callable has a **kwargs argument, it's ready to accept all the kwargs.
+ return cls(kwargs, wildcard=True)
+
+ # If the callable has no **kwargs argument, it only wants the arguments it requested.
+ kwargs = {key: kwargs[key] for key in signature.parameters if key in kwargs}
+ return cls(kwargs, wildcard=False)
+
+ def unpacking(self) -> Mapping[str, Any]:
+ """Dump the kwargs mapping to unpack with ``**`` in a function call."""
+ if self._wildcard and isinstance(self._kwargs, Context):
+ return lazy_mapping_from_context(self._kwargs)
+ return self._kwargs
+
+ def serializing(self) -> Mapping[str, Any]:
+ """Dump the kwargs mapping for serialization purposes."""
+ return self._kwargs
+
+
+def determine_kwargs(
+ func: Callable[..., Any],
+ args: Collection[Any],
+ kwargs: Mapping[str, Any],
+) -> Mapping[str, Any]:
"""
Inspect the signature of a given callable to determine which arguments in kwargs need
to be passed to the callable.
@@ -99,23 +161,7 @@ def determine_kwargs(func: Callable, args: Union[Tuple, List], kwargs: Mapping)
:param kwargs: The keyword arguments that need to be filtered before passing to the callable.
:return: A dictionary which contains the keyword arguments that are compatible with the callable.
"""
- import inspect
- import itertools
-
- signature = inspect.signature(func)
- has_kwargs = any(p.kind == p.VAR_KEYWORD for p in signature.parameters.values())
-
- for name in itertools.islice(signature.parameters.keys(), len(args)):
- # Check if args conflict with names in kwargs
- if name in kwargs:
- raise ValueError(f"The key {name} in args is part of kwargs and therefore reserved.")
-
- if has_kwargs:
- # If the callable has a **kwargs argument, it's ready to accept all the kwargs.
- return kwargs
-
- # If the callable has no **kwargs argument, it only wants the arguments it requested.
- return {key: kwargs[key] for key in signature.parameters if key in kwargs}
+ return KeywordParameters.determine(func, args, kwargs).unpacking()
def make_kwargs_callable(func: Callable) -> Callable:
diff --git a/scripts/ci/kubernetes/ci_run_kubernetes_tests.sh b/scripts/ci/kubernetes/ci_run_kubernetes_tests.sh
index a97f692..e586c30 100755
--- a/scripts/ci/kubernetes/ci_run_kubernetes_tests.sh
+++ b/scripts/ci/kubernetes/ci_run_kubernetes_tests.sh
@@ -52,10 +52,7 @@ function parse_tests_to_run() {
else
tests_to_run=("${@}")
fi
- pytest_args=(
- "--pythonwarnings=ignore::DeprecationWarning"
- "--pythonwarnings=ignore::PendingDeprecationWarning"
- )
+ pytest_args=()
else
tests_to_run=("kubernetes_tests")
pytest_args=(
@@ -64,8 +61,6 @@ function parse_tests_to_run() {
"--durations=100"
"--color=yes"
"--maxfail=50"
- "--pythonwarnings=ignore::DeprecationWarning"
- "--pythonwarnings=ignore::PendingDeprecationWarning"
)
fi
diff --git a/scripts/in_container/entrypoint_ci.sh b/scripts/in_container/entrypoint_ci.sh
index 29f5210..5d7aca0 100755
--- a/scripts/in_container/entrypoint_ci.sh
+++ b/scripts/in_container/entrypoint_ci.sh
@@ -209,8 +209,6 @@ EXTRA_PYTEST_ARGS=(
"--cov-report=xml:/files/coverage-${TEST_TYPE}-${BACKEND}.xml"
"--color=yes"
"--maxfail=50"
- "--pythonwarnings=ignore::DeprecationWarning"
- "--pythonwarnings=ignore::PendingDeprecationWarning"
"--junitxml=${RESULT_LOG_FILE}"
# timeouts in seconds for individual tests
"--timeouts-order"
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index 7d246c7..201af16 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -84,6 +84,7 @@ class TestCliTasks(unittest.TestCase):
args = self.parser.parse_args(['tasks', 'list', 'example_bash_operator', '--tree'])
task_command.task_list(args)
+ @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
def test_test(self):
"""Test the `airflow test` command"""
args = self.parser.parse_args(
@@ -96,6 +97,7 @@ class TestCliTasks(unittest.TestCase):
# Check that prints, and log messages, are shown
assert "'example_python_operator__print_the_context__20180101'" in stdout.getvalue()
+ @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
def test_test_with_existing_dag_run(self):
"""Test the `airflow test` command"""
task_id = 'print_the_context'
diff --git a/tests/core/test_core.py b/tests/core/test_core.py
index cae311d..02162e9 100644
--- a/tests/core/test_core.py
+++ b/tests/core/test_core.py
@@ -218,7 +218,7 @@ class TestCore:
op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
def test_python_op(self, dag_maker):
- def test_py_op(templates_dict, ds, **kwargs):
+ def test_py_op(templates_dict, ds):
if not templates_dict['ds'] == ds:
raise Exception("failure")
@@ -246,10 +246,6 @@ class TestCore:
assert context['ds'] == '2015-01-01'
assert context['ds_nodash'] == '20150101'
- # next_ds is 2015-01-02 as the dag schedule is daily.
- assert context['next_ds'] == '2015-01-02'
- assert context['next_ds_nodash'] == '20150102'
-
assert context['ts'] == '2015-01-01T00:00:00+00:00'
assert context['ts_nodash'] == '20150101T000000'
assert context['ts_nodash_with_tz'] == '20150101T000000+0000'
@@ -259,6 +255,8 @@ class TestCore:
# Test deprecated fields.
expected_deprecated_fields = [
+ ("next_ds", "2015-01-02"),
+ ("next_ds_nodash", "20150102"),
("prev_ds", "2014-12-31"),
("prev_ds_nodash", "20141231"),
("yesterday_ds", "2014-12-31"),
@@ -267,14 +265,17 @@ class TestCore:
("tomorrow_ds_nodash", "20150102"),
]
for key, expected_value in expected_deprecated_fields:
- message = (
+ message_beginning = (
f"Accessing {key!r} from the template is deprecated and "
f"will be removed in a future version."
)
with pytest.deprecated_call() as recorder:
value = str(context[key]) # Simulate template evaluation to trigger warning.
assert value == expected_value
- assert [str(m.message) for m in recorder] == [message]
+
+ recorded_message = [str(m.message) for m in recorder]
+ assert len(recorded_message) == 1
+ assert recorded_message[0].startswith(message_beginning)
def test_bad_trigger_rule(self, dag_maker):
with pytest.raises(AirflowException):
@@ -338,8 +339,10 @@ class TestCore:
context = ti.get_template_context()
# next_ds should be the execution date for manually triggered runs
- assert context['next_ds'] == execution_ds
- assert context['next_ds_nodash'] == execution_ds_nodash
+ with pytest.deprecated_call():
+ assert context['next_ds'] == execution_ds
+ with pytest.deprecated_call():
+ assert context['next_ds_nodash'] == execution_ds_nodash
def test_dag_params_and_task_params(self, dag_maker):
# This test case guards how params of DAG and Operator work together.
diff --git a/tests/operators/test_email.py b/tests/operators/test_email.py
index 5419796..ba2acda 100644
--- a/tests/operators/test_email.py
+++ b/tests/operators/test_email.py
@@ -50,7 +50,7 @@ class TestEmailOperator(unittest.TestCase):
html_content='The quick brown fox jumps over the lazy dog',
task_id='task',
dag=self.dag,
- files=["/tmp/Report-A-{{ execution_date.strftime('%Y-%m-%d') }}.csv"],
+ files=["/tmp/Report-A-{{ ds }}.csv"],
**kwargs,
)
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
diff --git a/tests/operators/test_python.py b/tests/operators/test_python.py
index 172468b..ac34468 100644
--- a/tests/operators/test_python.py
+++ b/tests/operators/test_python.py
@@ -19,6 +19,7 @@ import copy
import logging
import sys
import unittest.mock
+import warnings
from collections import namedtuple
from datetime import date, datetime, timedelta
from subprocess import CalledProcessError
@@ -39,6 +40,7 @@ from airflow.operators.python import (
get_current_context,
)
from airflow.utils import timezone
+from airflow.utils.context import AirflowContextDeprecationWarning
from airflow.utils.dates import days_ago
from airflow.utils.session import create_session
from airflow.utils.state import State
@@ -850,6 +852,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
# This tests might take longer than default 60 seconds as it is serializing a lot of
# context using dill (which is slow apparently).
@pytest.mark.execution_timeout(120)
+ @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
def test_airflow_context(self):
def f(
# basic
@@ -890,6 +893,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
self._run_as_operator(f, use_dill=True, system_site_packages=True, requirements=None)
+ @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
def test_pendulum_context(self):
def f(
# basic
@@ -923,6 +927,7 @@ class TestPythonVirtualenvOperator(unittest.TestCase):
self._run_as_operator(f, use_dill=True, system_site_packages=False, requirements=['pendulum'])
+ @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
def test_base_context(self):
def f(
# basic
@@ -1026,7 +1031,9 @@ class MyContextAssertOperator(BaseOperator):
def get_all_the_context(**context):
current_context = get_current_context()
- assert context == current_context._context
+ with warnings.catch_warnings():
+ warnings.simplefilter("ignore", AirflowContextDeprecationWarning)
+ assert context == current_context._context
@pytest.fixture()
diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py
index ea61687..9ff8735 100644
--- a/tests/operators/test_trigger_dagrun.py
+++ b/tests/operators/test_trigger_dagrun.py
@@ -152,7 +152,7 @@ class TestDagRunOperator(TestCase):
task = TriggerDagRunOperator(
task_id="test_trigger_dagrun_with_str_execution_date",
trigger_dag_id=TRIGGERED_DAG_ID,
- execution_date="{{ execution_date }}",
+ execution_date="{{ logical_date }}",
dag=self.dag,
)
task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
diff --git a/tests/providers/http/sensors/test_http.py b/tests/providers/http/sensors/test_http.py
index 3fc61bb..dc3b41f 100644
--- a/tests/providers/http/sensors/test_http.py
+++ b/tests/providers/http/sensors/test_http.py
@@ -125,8 +125,8 @@ class TestHttpSensor:
response.status_code = 200
mock_session_send.return_value = response
- def resp_check(_, execution_date):
- if execution_date == DEFAULT_DATE:
+ def resp_check(_, logical_date):
+ if logical_date == DEFAULT_DATE:
return True
raise AirflowException('AirflowException raised here!')
diff --git a/tests/sensors/test_external_task_sensor.py b/tests/sensors/test_external_task_sensor.py
index d1e150b..28018b9 100644
--- a/tests/sensors/test_external_task_sensor.py
+++ b/tests/sensors/test_external_task_sensor.py
@@ -174,7 +174,7 @@ class TestExternalTaskSensor(unittest.TestCase):
def test_external_task_sensor_fn_multiple_execution_dates(self):
bash_command_code = """
-{% set s=execution_date.time().second %}
+{% set s=logical_date.time().second %}
echo "second is {{ s }}"
if [[ $(( {{ s }} % 60 )) == 1 ]]
then
@@ -292,7 +292,7 @@ exit 0
self.test_time_sensor()
def my_func(dt, context):
- assert context['execution_date'] == dt
+ assert context['logical_date'] == dt
return dt + timedelta(0)
op1 = ExternalTaskSensor(
@@ -541,7 +541,7 @@ def dag_bag_parent_child():
task_id="task_1",
external_dag_id=dag_0.dag_id,
external_task_id=task_0.task_id,
- execution_date_fn=lambda execution_date: day_1 if execution_date == day_1 else [],
+ execution_date_fn=lambda logical_date: day_1 if logical_date == day_1 else [],
mode='reschedule',
)
@@ -884,7 +884,7 @@ def dag_bag_head_tail():
task_id="tail",
external_dag_id=dag.dag_id,
external_task_id=head.task_id,
- execution_date="{{ tomorrow_ds_nodash }}",
+ execution_date="{{ macros.ds_add(ds, 1) }}",
)
head >> body >> tail
diff --git a/tests/utils/test_log_handlers.py b/tests/utils/test_log_handlers.py
index 4503dd8..78166a8 100644
--- a/tests/utils/test_log_handlers.py
+++ b/tests/utils/test_log_handlers.py
@@ -62,7 +62,7 @@ class TestFileTaskLogHandler:
assert handler.name == FILE_TASK_HANDLER
def test_file_task_handler_when_ti_value_is_invalid(self):
- def task_callable(ti, **kwargs):
+ def task_callable(ti):
ti.log.info("test")
dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE)
@@ -114,7 +114,7 @@ class TestFileTaskLogHandler:
os.remove(log_filename)
def test_file_task_handler(self):
- def task_callable(ti, **kwargs):
+ def task_callable(ti):
ti.log.info("test")
dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE)
@@ -168,7 +168,7 @@ class TestFileTaskLogHandler:
os.remove(log_filename)
def test_file_task_handler_running(self):
- def task_callable(ti, **kwargs):
+ def task_callable(ti):
ti.log.info("test")
dag = DAG('dag_for_testing_file_task_handler', start_date=DEFAULT_DATE)