You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ur...@apache.org on 2022/07/29 08:58:28 UTC

[airflow] branch main updated: Check expand_kwargs() input type before unmapping (#25355)

This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4e786e31bc Check expand_kwargs() input type before unmapping (#25355)
4e786e31bc is described below

commit 4e786e31bcdf81427163918e14d191e55a4ab606
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Fri Jul 29 16:58:18 2022 +0800

    Check expand_kwargs() input type before unmapping (#25355)
---
 airflow/decorators/base.py        |  2 +-
 airflow/models/expandinput.py     | 27 ++++++++++++++++++++++-----
 airflow/models/mappedoperator.py  |  9 +++++----
 tests/models/test_taskinstance.py | 36 ++++++++++++++++++++++++++++++++++++
 tests/models/test_xcom_arg_map.py |  4 ++--
 5 files changed, 66 insertions(+), 12 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 0a4b75cece..49b8f055dc 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -455,7 +455,7 @@ class DecoratedMappedOperator(MappedOperator):
         assert self.expand_input is EXPAND_INPUT_EMPTY
         return {"op_kwargs": super()._expand_mapped_kwargs(resolve)}
 
-    def _get_unmap_kwargs(self, mapped_kwargs: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
+    def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> Dict[str, Any]:
         if strict:
             prevent_duplicates(
                 self.partial_kwargs["op_kwargs"],
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index f3fa84a302..5a94698f9b 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -22,7 +22,7 @@ import collections
 import collections.abc
 import functools
 import operator
-from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Sequence, Sized, Union
+from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple, Sequence, Sized, Union
 
 from sqlalchemy import func
 from sqlalchemy.orm import Session
@@ -195,10 +195,16 @@ class DictOfListsExpandInput(NamedTuple):
                 return k, v
         raise IndexError(f"index {map_index} is over mapped length")
 
-    def resolve(self, context: Context, session: Session) -> dict[str, Any]:
+    def resolve(self, context: Context, session: Session) -> Mapping[str, Any]:
         return {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()}
 
 
+def _describe_type(value: Any) -> str:
+    if value is None:
+        return "None"
+    return type(value).__name__
+
+
 class ListOfDictsExpandInput(NamedTuple):
     """Storage type of a mapped operator's mapped kwargs.
 
@@ -245,12 +251,23 @@ class ListOfDictsExpandInput(NamedTuple):
             raise NotFullyPopulated({"expand_kwargs() argument"})
         return value
 
-    def resolve(self, context: Context, session: Session) -> dict[str, Any]:
+    def resolve(self, context: Context, session: Session) -> Mapping[str, Any]:
         map_index = context["ti"].map_index
         if map_index < 0:
             raise RuntimeError("can't resolve task-mapping argument without expanding")
-        # Validation should be done when the upstream returns.
-        return self.value.resolve(context, session)[map_index]
+        mappings = self.value.resolve(context, session)
+        if not isinstance(mappings, collections.abc.Sequence):
+            raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
+        mapping = mappings[map_index]
+        if not isinstance(mapping, collections.abc.Mapping):
+            raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]")
+        for key in mapping:
+            if not isinstance(key, str):
+                raise ValueError(
+                    f"expand_kwargs() input dict keys must all be str, "
+                    f"but {key!r} is of type {_describe_type(key)}"
+                )
+        return mapping
 
 
 EXPAND_INPUT_EMPTY = DictOfListsExpandInput({})  # Sentinel value.
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index dead62753c..53ea072a79 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -30,6 +30,7 @@ from typing import (
     Iterable,
     Iterator,
     List,
+    Mapping,
     Optional,
     Sequence,
     Set,
@@ -123,7 +124,7 @@ def validate_mapping_kwargs(op: Type["BaseOperator"], func: ValidationSource, va
     raise TypeError(f"{op.__name__}.{func}() got {error}")
 
 
-def prevent_duplicates(kwargs1: Dict[str, Any], kwargs2: Dict[str, Any], *, fail_reason: str) -> None:
+def prevent_duplicates(kwargs1: Dict[str, Any], kwargs2: Mapping[str, Any], *, fail_reason: str) -> None:
     duplicated_keys = set(kwargs1).intersection(kwargs2)
     if not duplicated_keys:
         return
@@ -528,7 +529,7 @@ class MappedOperator(AbstractOperator):
         """Implementing DAGNode."""
         return DagAttributeTypes.OP, self.task_id
 
-    def _expand_mapped_kwargs(self, resolve: Optional[Tuple[Context, Session]]) -> Dict[str, Any]:
+    def _expand_mapped_kwargs(self, resolve: Optional[Tuple[Context, Session]]) -> Mapping[str, Any]:
         """Get the kwargs to create the unmapped operator.
 
         If *resolve* is not *None*, it must be a two-tuple to provide context to
@@ -546,7 +547,7 @@ class MappedOperator(AbstractOperator):
             return expand_input.resolve(*resolve)
         return expand_input.get_unresolved_kwargs()
 
-    def _get_unmap_kwargs(self, mapped_kwargs: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
+    def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> Dict[str, Any]:
         """Get init kwargs to unmap the underlying operator class.
 
         :param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``.
@@ -569,7 +570,7 @@ class MappedOperator(AbstractOperator):
             **mapped_kwargs,
         }
 
-    def unmap(self, resolve: Union[None, Dict[str, Any], Tuple[Context, Session]]) -> "BaseOperator":
+    def unmap(self, resolve: Union[None, Mapping[str, Any], Tuple[Context, Session]]) -> "BaseOperator":
         """Get the "normal" Operator after applying the current mapping.
 
         If ``operator_class`` is not a class (i.e. this DAG has been
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 21badfe8cb..890e30f541 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -34,6 +34,7 @@ import pytest
 from freezegun import freeze_time
 
 from airflow import models, settings
+from airflow.decorators import task
 from airflow.example_dags.plugins.workday import AfterWorkdayTimetable
 from airflow.exceptions import (
     AirflowException,
@@ -2640,6 +2641,41 @@ class TestTaskInstanceRecordTaskMapXComPush:
         assert ti.state == TaskInstanceState.FAILED
         assert str(ctx.value) == error_message
 
+    @pytest.mark.parametrize(
+        "create_upstream",
+        [
+            # The task returns an invalid expand_kwargs() input (a list[int] instead of list[dict]).
+            pytest.param(lambda: task(task_id="push")(lambda: [0])(), id="normal"),
+            # This task returns a list[dict] (correct), but we use map() to transform it to list[int] (wrong).
+            pytest.param(lambda: task(task_id="push")(lambda: [{"v": ""}])().map(lambda _: 0), id="mapped"),
+        ],
+    )
+    def test_expand_kwargs_error_if_received_invalid(self, dag_maker, session, create_upstream):
+        with dag_maker(dag_id="test_expand_kwargs_error_if_received_invalid", session=session):
+            push_task = create_upstream()
+
+            @task()
+            def pull(v):
+                print(v)
+
+            pull.expand_kwargs(push_task)
+
+        dr = dag_maker.create_dagrun()
+
+        # Run "push".
+        decision = dr.task_instance_scheduling_decisions(session=session)
+        assert decision.schedulable_tis
+        for ti in decision.schedulable_tis:
+            ti.run()
+
+        # Run "pull".
+        decision = dr.task_instance_scheduling_decisions(session=session)
+        assert decision.schedulable_tis
+        for ti in decision.schedulable_tis:
+            with pytest.raises(ValueError) as ctx:
+                ti.run()
+            assert str(ctx.value) == "expand_kwargs() expects a list[dict], not list[int]"
+
     @pytest.mark.parametrize(
         "downstream, error_message",
         [
diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py
index 5807f1b296..144eb20327 100644
--- a/tests/models/test_xcom_arg_map.py
+++ b/tests/models/test_xcom_arg_map.py
@@ -128,9 +128,9 @@ def test_xcom_convert_to_kwargs_fails_task(dag_maker, session):
     tis[("pull", 1)].run()
 
     # But the third one fails because the map() result cannot be used as kwargs.
-    with pytest.raises(TypeError) as ctx:
+    with pytest.raises(ValueError) as ctx:
         tis[("pull", 2)].run()
-    assert str(ctx.value) == "'NoneType' object is not iterable"
+    assert str(ctx.value) == "expand_kwargs() expects a list[dict], not list[None]"
 
     assert [tis[("pull", i)].state for i in range(3)] == [
         TaskInstanceState.SUCCESS,