You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@airflow.apache.org by ur...@apache.org on 2022/07/29 08:58:28 UTC
[airflow] branch main updated: Check expand_kwargs() input type before unmapping (#25355)
This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4e786e31bc Check expand_kwargs() input type before unmapping (#25355)
4e786e31bc is described below
commit 4e786e31bcdf81427163918e14d191e55a4ab606
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Fri Jul 29 16:58:18 2022 +0800
Check expand_kwargs() input type before unmapping (#25355)
---
airflow/decorators/base.py | 2 +-
airflow/models/expandinput.py | 27 ++++++++++++++++++++++-----
airflow/models/mappedoperator.py | 9 +++++----
tests/models/test_taskinstance.py | 36 ++++++++++++++++++++++++++++++++++++
tests/models/test_xcom_arg_map.py | 4 ++--
5 files changed, 66 insertions(+), 12 deletions(-)
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 0a4b75cece..49b8f055dc 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -455,7 +455,7 @@ class DecoratedMappedOperator(MappedOperator):
assert self.expand_input is EXPAND_INPUT_EMPTY
return {"op_kwargs": super()._expand_mapped_kwargs(resolve)}
- def _get_unmap_kwargs(self, mapped_kwargs: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
+ def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> Dict[str, Any]:
if strict:
prevent_duplicates(
self.partial_kwargs["op_kwargs"],
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index f3fa84a302..5a94698f9b 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -22,7 +22,7 @@ import collections
import collections.abc
import functools
import operator
-from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Sequence, Sized, Union
+from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple, Sequence, Sized, Union
from sqlalchemy import func
from sqlalchemy.orm import Session
@@ -195,10 +195,16 @@ class DictOfListsExpandInput(NamedTuple):
return k, v
raise IndexError(f"index {map_index} is over mapped length")
- def resolve(self, context: Context, session: Session) -> dict[str, Any]:
+ def resolve(self, context: Context, session: Session) -> Mapping[str, Any]:
return {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()}
+def _describe_type(value: Any) -> str:
+ if value is None:
+ return "None"
+ return type(value).__name__
+
+
class ListOfDictsExpandInput(NamedTuple):
"""Storage type of a mapped operator's mapped kwargs.
@@ -245,12 +251,23 @@ class ListOfDictsExpandInput(NamedTuple):
raise NotFullyPopulated({"expand_kwargs() argument"})
return value
- def resolve(self, context: Context, session: Session) -> dict[str, Any]:
+ def resolve(self, context: Context, session: Session) -> Mapping[str, Any]:
map_index = context["ti"].map_index
if map_index < 0:
raise RuntimeError("can't resolve task-mapping argument without expanding")
- # Validation should be done when the upstream returns.
- return self.value.resolve(context, session)[map_index]
+ mappings = self.value.resolve(context, session)
+ if not isinstance(mappings, collections.abc.Sequence):
+ raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
+ mapping = mappings[map_index]
+ if not isinstance(mapping, collections.abc.Mapping):
+ raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]")
+ for key in mapping:
+ if not isinstance(key, str):
+ raise ValueError(
+ f"expand_kwargs() input dict keys must all be str, "
+ f"but {key!r} is of type {_describe_type(key)}"
+ )
+ return mapping
EXPAND_INPUT_EMPTY = DictOfListsExpandInput({}) # Sentinel value.
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index dead62753c..53ea072a79 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -30,6 +30,7 @@ from typing import (
Iterable,
Iterator,
List,
+ Mapping,
Optional,
Sequence,
Set,
@@ -123,7 +124,7 @@ def validate_mapping_kwargs(op: Type["BaseOperator"], func: ValidationSource, va
raise TypeError(f"{op.__name__}.{func}() got {error}")
-def prevent_duplicates(kwargs1: Dict[str, Any], kwargs2: Dict[str, Any], *, fail_reason: str) -> None:
+def prevent_duplicates(kwargs1: Dict[str, Any], kwargs2: Mapping[str, Any], *, fail_reason: str) -> None:
duplicated_keys = set(kwargs1).intersection(kwargs2)
if not duplicated_keys:
return
@@ -528,7 +529,7 @@ class MappedOperator(AbstractOperator):
"""Implementing DAGNode."""
return DagAttributeTypes.OP, self.task_id
- def _expand_mapped_kwargs(self, resolve: Optional[Tuple[Context, Session]]) -> Dict[str, Any]:
+ def _expand_mapped_kwargs(self, resolve: Optional[Tuple[Context, Session]]) -> Mapping[str, Any]:
"""Get the kwargs to create the unmapped operator.
If *resolve* is not *None*, it must be a two-tuple to provide context to
@@ -546,7 +547,7 @@ class MappedOperator(AbstractOperator):
return expand_input.resolve(*resolve)
return expand_input.get_unresolved_kwargs()
- def _get_unmap_kwargs(self, mapped_kwargs: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
+ def _get_unmap_kwargs(self, mapped_kwargs: Mapping[str, Any], *, strict: bool) -> Dict[str, Any]:
"""Get init kwargs to unmap the underlying operator class.
:param mapped_kwargs: The dict returned by ``_expand_mapped_kwargs``.
@@ -569,7 +570,7 @@ class MappedOperator(AbstractOperator):
**mapped_kwargs,
}
- def unmap(self, resolve: Union[None, Dict[str, Any], Tuple[Context, Session]]) -> "BaseOperator":
+ def unmap(self, resolve: Union[None, Mapping[str, Any], Tuple[Context, Session]]) -> "BaseOperator":
"""Get the "normal" Operator after applying the current mapping.
If ``operator_class`` is not a class (i.e. this DAG has been
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 21badfe8cb..890e30f541 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -34,6 +34,7 @@ import pytest
from freezegun import freeze_time
from airflow import models, settings
+from airflow.decorators import task
from airflow.example_dags.plugins.workday import AfterWorkdayTimetable
from airflow.exceptions import (
AirflowException,
@@ -2640,6 +2641,41 @@ class TestTaskInstanceRecordTaskMapXComPush:
assert ti.state == TaskInstanceState.FAILED
assert str(ctx.value) == error_message
+ @pytest.mark.parametrize(
+ "create_upstream",
+ [
+ # The task returns an invalid expand_kwargs() input (a list[int] instead of list[dict]).
+ pytest.param(lambda: task(task_id="push")(lambda: [0])(), id="normal"),
+ # This task returns a list[dict] (correct), but we use map() to transform it to list[int] (wrong).
+ pytest.param(lambda: task(task_id="push")(lambda: [{"v": ""}])().map(lambda _: 0), id="mapped"),
+ ],
+ )
+ def test_expand_kwargs_error_if_received_invalid(self, dag_maker, session, create_upstream):
+ with dag_maker(dag_id="test_expand_kwargs_error_if_received_invalid", session=session):
+ push_task = create_upstream()
+
+ @task()
+ def pull(v):
+ print(v)
+
+ pull.expand_kwargs(push_task)
+
+ dr = dag_maker.create_dagrun()
+
+ # Run "push".
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ assert decision.schedulable_tis
+ for ti in decision.schedulable_tis:
+ ti.run()
+
+ # Run "pull".
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ assert decision.schedulable_tis
+ for ti in decision.schedulable_tis:
+ with pytest.raises(ValueError) as ctx:
+ ti.run()
+ assert str(ctx.value) == "expand_kwargs() expects a list[dict], not list[int]"
+
@pytest.mark.parametrize(
"downstream, error_message",
[
diff --git a/tests/models/test_xcom_arg_map.py b/tests/models/test_xcom_arg_map.py
index 5807f1b296..144eb20327 100644
--- a/tests/models/test_xcom_arg_map.py
+++ b/tests/models/test_xcom_arg_map.py
@@ -128,9 +128,9 @@ def test_xcom_convert_to_kwargs_fails_task(dag_maker, session):
tis[("pull", 1)].run()
# But the third one fails because the map() result cannot be used as kwargs.
- with pytest.raises(TypeError) as ctx:
+ with pytest.raises(ValueError) as ctx:
tis[("pull", 2)].run()
- assert str(ctx.value) == "'NoneType' object is not iterable"
+ assert str(ctx.value) == "expand_kwargs() expects a list[dict], not list[None]"
assert [tis[("pull", i)].state for i in range(3)] == [
TaskInstanceState.SUCCESS,