You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ur...@apache.org on 2022/02/24 08:39:50 UTC

[airflow] branch main updated: Rewrite taskflow-mapping argument validation (#21759)

This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2da207d  Rewrite taskflow-mapping argument validation (#21759)
2da207d is described below

commit 2da207d4211ea5070c4a79756c93c766abe7b397
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Thu Feb 24 16:38:59 2022 +0800

    Rewrite taskflow-mapping argument validation (#21759)
---
 airflow/decorators/base.py       | 76 ++++++++++++++++++++++------------------
 airflow/models/mappedoperator.py |  2 +-
 airflow/models/taskinstance.py   |  3 ++
 airflow/utils/context.py         | 45 ++++++++++++++++++++++++
 airflow/utils/context.pyi        |  5 ++-
 tests/decorators/test_python.py  | 61 +++++++++++++++++++++++++-------
 6 files changed, 144 insertions(+), 48 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 46e4c0a..38354d9 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -19,7 +19,6 @@ import collections.abc
 import functools
 import inspect
 import re
-import sys
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -52,12 +51,13 @@ from airflow.models.mappedoperator import (
     ValidationSource,
     create_mocked_kwargs,
     get_mappable_types,
+    prevent_duplicates,
 )
 from airflow.models.pool import Pool
 from airflow.models.xcom_arg import XComArg
 from airflow.typing_compat import Protocol
 from airflow.utils import timezone
-from airflow.utils.context import Context
+from airflow.utils.context import KNOWN_CONTEXT_KEYS, Context
 from airflow.utils.task_group import TaskGroup, TaskGroupContext
 from airflow.utils.types import NOTSET
 
@@ -227,45 +227,25 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
     :meta private:
     """
 
-    function: Function = attr.ib(validator=attr.validators.is_callable())
+    function: Function = attr.ib()
     operator_class: Type[OperatorSubclass]
     multiple_outputs: bool = attr.ib()
     kwargs: Dict[str, Any] = attr.ib(factory=dict)
 
     decorator_name: str = attr.ib(repr=False, default="task")
 
-    @cached_property
-    def function_signature(self):
-        return inspect.signature(self.function)
-
-    @cached_property
-    def function_arg_names(self) -> Set[str]:
-        return set(self.function_signature.parameters)
-
-    @function.validator
-    def _validate_function(self, _, f):
-        if 'self' in self.function_arg_names:
-            raise TypeError(f'@{self.decorator_name} does not support methods')
-
     @multiple_outputs.default
     def _infer_multiple_outputs(self):
         try:
             return_type = typing_extensions.get_type_hints(self.function).get("return", Any)
         except Exception:  # Can't evaluate retrurn type.
             return False
-
-        # Get the non-subscripted type. The ``__origin__`` attribute is not
-        # stable until 3.7, but we need to use ``__extra__`` instead.
-        # TODO: Remove the ``__extra__`` branch when support for Python 3.6 is
-        # dropped in Airflow 2.3.
-        if sys.version_info < (3, 7):
-            ttype = getattr(return_type, "__extra__", return_type)
-        else:
-            ttype = getattr(return_type, "__origin__", return_type)
-
+        ttype = getattr(return_type, "__origin__", return_type)
         return ttype == dict or ttype == Dict
 
     def __attrs_post_init__(self):
+        if "self" in self.function_signature.parameters:
+            raise TypeError(f"@{self.decorator_name} does not support methods")
         self.kwargs.setdefault('task_id', self.function.__name__)
 
     def __call__(self, *args, **kwargs) -> XComArg:
@@ -280,22 +260,50 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
             op.doc_md = self.function.__doc__
         return XComArg(op)
 
+    @cached_property
+    def function_signature(self):
+        return inspect.signature(self.function)
+
+    @cached_property
+    def _function_is_vararg(self):
+        return any(
+            v.kind == inspect.Parameter.VAR_KEYWORD for v in self.function_signature.parameters.values()
+        )
+
+    @cached_property
+    def _mappable_function_argument_names(self) -> Set[str]:
+        """Arguments that can be mapped against."""
+        return set(self.function_signature.parameters)
+
     def _validate_arg_names(self, func: ValidationSource, kwargs: Dict[str, Any]):
+        # Ensure that context variables are not shadowed.
+        context_keys_being_mapped = KNOWN_CONTEXT_KEYS.intersection(kwargs)
+        if len(context_keys_being_mapped) == 1:
+            (name,) = context_keys_being_mapped
+            raise ValueError(f"cannot call {func}() on task context variable {name!r}")
+        elif context_keys_being_mapped:
+            names = ", ".join(repr(n) for n in context_keys_being_mapped)
+            raise ValueError(f"cannot call {func}() on task context variables {names}")
+
+        # Ensure that all arguments passed in are accounted for.
+        if self._function_is_vararg:
+            return
         kwargs_left = kwargs.copy()
-        for arg_name in self.function_arg_names:
+        for arg_name in self._mappable_function_argument_names:
             value = kwargs_left.pop(arg_name, NOTSET)
             if func != "map" or value is NOTSET or isinstance(value, get_mappable_types()):
                 continue
-            raise ValueError(f"{func} got unexpected value{type(value)!r} for keyword argument {arg_name!r}")
-
+            type_name = type(value).__name__
+            raise ValueError(f"map() got an unexpected type {type_name!r} for keyword argument {arg_name!r}")
         if len(kwargs_left) == 1:
-            raise TypeError(f"{func} got unexpected keyword argument {next(iter(kwargs_left))!r}")
+            raise TypeError(f"{func}() got an unexpected keyword argument {next(iter(kwargs_left))!r}")
         elif kwargs_left:
             names = ", ".join(repr(n) for n in kwargs_left)
-            raise TypeError(f"{func} got unexpected keyword arguments {names}")
+            raise TypeError(f"{func}() got unexpected keyword arguments {names}")
 
-    def map(self, **kwargs: "MapArgument") -> XComArg:
-        self._validate_arg_names("map", kwargs)
+    def map(self, **map_kwargs: "MapArgument") -> XComArg:
+        self._validate_arg_names("map", map_kwargs)
+        prevent_duplicates(self.kwargs, map_kwargs, fail_reason="mapping already partial")
 
         partial_kwargs = self.kwargs.copy()
 
@@ -345,7 +353,7 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
             end_date=end_date,
             multiple_outputs=self.multiple_outputs,
             python_callable=self.function,
-            mapped_op_kwargs=kwargs,
+            mapped_op_kwargs=map_kwargs,
         )
         return XComArg(operator=operator)
 
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 99b70fb..a0c1620 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -111,7 +111,7 @@ def validate_mapping_kwargs(op: Type["BaseOperator"], func: ValidationSource, va
             if isinstance(value, get_mappable_types()):
                 continue
             type_name = type(value).__name__
-            error = f"{op.__name__}.map() got unexpected type {type_name!r} for keyword argument {name}"
+            error = f"{op.__name__}.map() got an unexpected type {type_name!r} for keyword argument {name}"
             raise ValueError(error)
         if not unknown_args:
             return  # If we have no args left ot check: stop looking at the MRO chian.
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 34f5a6e..954aaf6 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -1929,6 +1929,9 @@ class TaskInstance(Base, LoggingMixin):
                 return None
             return prev_ds.replace('-', '')
 
+        # NOTE: If you add anything to this dict, make sure to also update the
+        # definition in airflow/utils/context.pyi, and KNOWN_CONTEXT_KEYS in
+        # airflow/utils/context.py!
         context = {
             'conf': conf,
             'dag': dag,
diff --git a/airflow/utils/context.py b/airflow/utils/context.py
index 2cd8708..04dabab 100644
--- a/airflow/utils/context.py
+++ b/airflow/utils/context.py
@@ -41,6 +41,51 @@ import lazy_object_proxy
 
 from airflow.utils.types import NOTSET
 
+# NOTE: Please keep this in sync with Context in airflow/utils/context.pyi.
+KNOWN_CONTEXT_KEYS = {
+    "conf",
+    "conn",
+    "dag",
+    "dag_run",
+    "data_interval_end",
+    "data_interval_start",
+    "ds",
+    "ds_nodash",
+    "execution_date",
+    "exception",
+    "inlets",
+    "logical_date",
+    "macros",
+    "next_ds",
+    "next_ds_nodash",
+    "next_execution_date",
+    "outlets",
+    "params",
+    "prev_data_interval_start_success",
+    "prev_data_interval_end_success",
+    "prev_ds",
+    "prev_ds_nodash",
+    "prev_execution_date",
+    "prev_execution_date_success",
+    "prev_start_date_success",
+    "run_id",
+    "task",
+    "task_instance",
+    "task_instance_key_str",
+    "test_mode",
+    "templates_dict",
+    "ti",
+    "tomorrow_ds",
+    "tomorrow_ds_nodash",
+    "ts",
+    "ts_nodash",
+    "ts_nodash_with_tz",
+    "try_number",
+    "var",
+    "yesterday_ds",
+    "yesterday_ds_nodash",
+}
+
 
 class VariableAccessor:
     """Wrapper to access Variable values in template."""
diff --git a/airflow/utils/context.pyi b/airflow/utils/context.pyi
index f614459..6003d1d 100644
--- a/airflow/utils/context.pyi
+++ b/airflow/utils/context.pyi
@@ -25,7 +25,7 @@
 # undefined attribute errors from Mypy. Hopefully there will be a mechanism to
 # declare "these are defined, but don't error if others are accessed" someday.
 
-from typing import Any, Container, Iterable, Mapping, Optional, Tuple, Union, overload
+from typing import Any, Container, Iterable, Mapping, Optional, Set, Tuple, Union, overload
 
 from pendulum import DateTime
 
@@ -37,6 +37,8 @@ from airflow.models.param import ParamsDict
 from airflow.models.taskinstance import TaskInstance
 from airflow.typing_compat import TypedDict
 
+KNOWN_CONTEXT_KEYS: Set[str]
+
 class _VariableAccessors(TypedDict):
     json: Any
     value: Any
@@ -48,6 +50,7 @@ class VariableAccessor:
 class ConnectionAccessor:
     def get(self, key: str, default_conn: Any = None) -> Any: ...
 
+# NOTE: Please keep this in sync with KNOWN_CONTEXT_KEYS in airflow/utils/context.py.
 class Context(TypedDict):
     conf: AirflowConfigParser
     conn: Any
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 4f6d0bc..e127ab6 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -477,22 +477,59 @@ class TestAirflowTaskDecorator:
         assert ret.operator.doc_md.strip(), "Adds 2 to number."
 
 
-def test_mapped_decorator() -> None:
+def test_mapped_decorator_shadow_context() -> None:
     @task_decorator
-    def double(number: int):
-        return number * 2
+    def print_info(message: str, run_id: str = "") -> None:
+        print(f"{run_id}: {message}")
+
+    with pytest.raises(ValueError) as ctx:
+        print_info.partial(run_id="hi")
+    assert str(ctx.value) == "cannot call partial() on task context variable 'run_id'"
+
+    with pytest.raises(ValueError) as ctx:
+        print_info.map(run_id=["hi", "there"])
+    assert str(ctx.value) == "cannot call map() on task context variable 'run_id'"
+
+
+def test_mapped_decorator_wrong_argument() -> None:
+    @task_decorator
+    def print_info(message: str, run_id: str = "") -> None:
+        print(f"{run_id}: {message}")
+
+    with pytest.raises(TypeError) as ct:
+        print_info.partial(wrong_name="hi")
+    assert str(ct.value) == "partial() got an unexpected keyword argument 'wrong_name'"
+
+    with pytest.raises(TypeError) as ct:
+        print_info.map(wrong_name=["hi", "there"])
+    assert str(ct.value) == "map() got an unexpected keyword argument 'wrong_name'"
+
+    with pytest.raises(ValueError) as cv:
+        print_info.map(message="hi")
+    assert str(cv.value) == "map() got an unexpected type 'str' for keyword argument 'message'"
+
+
+def test_mapped_decorator():
+    @task_decorator
+    def print_info(m1: str, m2: str, run_id: str = "") -> None:
+        print(f"{run_id}: {m1} {m2}")
+
+    @task_decorator
+    def print_everything(**kwargs) -> None:
+        print(kwargs)
 
-    with DAG('test_dag', start_date=DEFAULT_DATE):
-        literal = [1, 2, 3]
-        doubled_0 = double.map(number=literal)
-        doubled_1 = double.map(number=literal)
+    with DAG("test_mapped_decorator", start_date=DEFAULT_DATE):
+        t0 = print_info.map(m1=["a", "b"], m2={"foo": "bar"})
+        t1 = print_info.partial(m1="hi").map(m2=[1, 2, 3])
+        t2 = print_everything.partial(whatever="123").map(any_key=[1, 2], works=t1)
 
-    assert isinstance(doubled_0, XComArg)
-    assert isinstance(doubled_0.operator, DecoratedMappedOperator)
-    assert doubled_0.operator.task_id == "double"
-    assert doubled_0.operator.mapped_op_kwargs == {"number": literal}
+    assert isinstance(t2, XComArg)
+    assert isinstance(t2.operator, DecoratedMappedOperator)
+    assert t2.operator.task_id == "print_everything"
+    assert t2.operator.mapped_op_kwargs == {"any_key": [1, 2], "works": t1}
 
-    assert doubled_1.operator.task_id == "double__1"
+    assert t0.operator.task_id == "print_info"
+    assert t1.operator.task_id == "print_info__1"
 
 
 def test_mapped_decorator_invalid_args() -> None: