You are viewing a plain text version of this content. The canonical link for it was a hyperlink on the original page and is not preserved in this plain-text version.
Posted to commits@airflow.apache.org by ur...@apache.org on 2022/07/14 05:28:08 UTC
[airflow] branch main updated: Implement expand_kwargs() (#24989)
This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7d95bd9f41 Implement expand_kwargs() (#24989)
7d95bd9f41 is described below
commit 7d95bd9f416c9319f6b5c00058b0a1e3bd5bf805
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Thu Jul 14 13:28:01 2022 +0800
Implement expand_kwargs() (#24989)
---
airflow/decorators/base.py | 34 +++++--
airflow/exceptions.py | 20 +++-
airflow/models/baseoperator.py | 3 +-
airflow/models/expandinput.py | 95 ++++++++++++++++++-
airflow/models/mappedoperator.py | 27 +++---
tests/models/test_dagrun.py | 21 +++++
tests/models/test_mappedoperator.py | 107 ++++++++++++++++++++-
tests/models/test_taskinstance.py | 90 ++++++++++++++++++
tests/serialization/test_dag_serialization.py | 129 ++++++++++++++++++++++++++
9 files changed, 496 insertions(+), 30 deletions(-)
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 9598571240..a36d7d5d43 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -53,7 +53,12 @@ from airflow.models.baseoperator import (
parse_retries,
)
from airflow.models.dag import DAG, DagContext
-from airflow.models.expandinput import EXPAND_INPUT_EMPTY, DictOfListsExpandInput, ExpandInput
+from airflow.models.expandinput import (
+ EXPAND_INPUT_EMPTY,
+ DictOfListsExpandInput,
+ ExpandInput,
+ ListOfDictsExpandInput,
+)
from airflow.models.mappedoperator import (
MappedOperator,
ValidationSource,
@@ -171,8 +176,17 @@ class DecoratedOperator(BaseOperator):
op_args = op_args or []
op_kwargs = op_kwargs or {}
- # Check that arguments can be binded
- inspect.signature(python_callable).bind(*op_args, **op_kwargs)
+ # Check that arguments can be bound. There's a slight difference when
+ # we do validation for task-mapping: Since there's no guarantee we can
+ # receive enough arguments at parse time, we use bind_partial to simply
+ # check that the arguments we do know about are valid. Whether these are
+ # enough can only be known at execution time, when unmapping happens, and
+ # this is called without the _airflow_mapped_validation_only flag.
+ if kwargs.get("_airflow_mapped_validation_only"):
+ inspect.signature(python_callable).bind_partial(*op_args, **op_kwargs)
+ else:
+ inspect.signature(python_callable).bind(*op_args, **op_kwargs)
+
self.multiple_outputs = multiple_outputs
self.op_args = op_args
self.op_kwargs = op_kwargs
@@ -323,6 +337,13 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
# to False to skip the checks on execution.
return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)
+ def expand_kwargs(self, kwargs: "XComArg", *, strict: bool = True) -> XComArg:
+ from airflow.models.xcom_arg import XComArg
+
+ if not isinstance(kwargs, XComArg):
+ raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}")
+ return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)
+
def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
ensure_xcomarg_return_value(expand_input.value)
@@ -442,10 +463,11 @@ class DecoratedMappedOperator(MappedOperator):
mapped_kwargs["op_kwargs"],
fail_reason="mapping already partial",
)
+
+ static_kwargs = {k for k, _ in self.op_kwargs_expand_input.iter_parse_time_resolved_kwargs()}
self._combined_op_kwargs = {**self.partial_kwargs["op_kwargs"], **mapped_kwargs["op_kwargs"]}
- self._already_resolved_op_kwargs = {
- k for k, v in self.op_kwargs_expand_input.value.items() if isinstance(v, XComArg)
- }
+ self._already_resolved_op_kwargs = {k for k in mapped_kwargs["op_kwargs"] if k not in static_kwargs}
+
kwargs = {
"multiple_outputs": self.multiple_outputs,
"python_callable": self.python_callable,
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index f1a8c1cb66..7a91100f11 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -113,12 +113,26 @@ class XComForMappingNotPushed(AirflowException):
class UnmappableXComTypePushed(AirflowException):
"""Raise when an unmappable type is pushed as a mapped downstream's dependency."""
- def __init__(self, value: Any) -> None:
- super().__init__(value)
+ def __init__(self, value: Any, *values: Any) -> None:
+ super().__init__(value, *values)
+
+ def __str__(self) -> str:
+ typename = type(self.args[0]).__qualname__
+ for arg in self.args[1:]:
+ typename = f"{typename}[{type(arg).__qualname__}]"
+ return f"unmappable return type {typename!r}"
+
+
+class UnmappableXComValuePushed(AirflowException):
+ """Raise when an invalid value is pushed as a mapped downstream's dependency."""
+
+ def __init__(self, value: Any, reason: str) -> None:
+ super().__init__(value, reason)
self.value = value
+ self.reason = reason
def __str__(self) -> str:
- return f"unmappable return type {type(self.value).__qualname__!r}"
+ return f"unmappable return value {self.value!r} ({self.reason})"
class UnmappableXComLengthPushed(AirflowException):
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 05ef24a124..2795a0f538 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -756,6 +756,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
super().__init__()
+ kwargs.pop("_airflow_mapped_validation_only", None)
if kwargs:
if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
raise AirflowException(
@@ -1509,7 +1510,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
def validate_mapped_arguments(cls, **kwargs: Any) -> None:
"""Validate arguments when this operator is being mapped."""
if cls.mapped_arguments_validated_by_init:
- cls(**kwargs, _airflow_from_mapped=True)
+ cls(**kwargs, _airflow_from_mapped=True, _airflow_mapped_validation_only=True)
def unmap(self, ctx: Union[None, Dict[str, Any], Tuple[Context, Session]]) -> "BaseOperator":
""":meta private:"""
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index 86623b41a1..b5b922f9df 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -22,26 +22,36 @@ import collections
import collections.abc
import functools
import operator
-from typing import TYPE_CHECKING, Any, NamedTuple, Sequence, Union
+from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Sequence, Sized, Union
from sqlalchemy import func
from sqlalchemy.orm import Session
-from airflow.exceptions import UnmappableXComTypePushed
+from airflow.compat.functools import cache
+from airflow.exceptions import UnmappableXComTypePushed, UnmappableXComValuePushed
from airflow.utils.context import Context
if TYPE_CHECKING:
from airflow.models.xcom_arg import XComArg
+ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]
+
# BaseOperator.expand() can be called on an XComArg, sequence, or dict (not any
# mapping since we need the value to be ordered).
Mappable = Union["XComArg", Sequence, dict]
-MAPPABLE_LITERAL_TYPES = (dict, list)
+
+# For isinstance() check.
+@cache
+def get_mappable_types() -> tuple[type, ...]:
+ from airflow.models.xcom_arg import XComArg
+
+ return (XComArg, list, tuple, dict)
class NotFullyPopulated(RuntimeError):
"""Raise when ``get_map_lengths`` cannot populate all mapping metadata.
+
This is generally due to not all upstream tasks have finished when the
function is called.
"""
@@ -67,10 +77,20 @@ class DictOfListsExpandInput(NamedTuple):
if not isinstance(value, collections.abc.Collection) or isinstance(value, (bytes, str)):
raise UnmappableXComTypePushed(value)
+ def get_unresolved_kwargs(self) -> dict[str, Any]:
+ """Get the kwargs dict that can be inferred without resolving."""
+ return self.value
+
+ def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
+ """Generate kwargs with values available on parse-time."""
+ from airflow.models.xcom_arg import XComArg
+
+ return ((k, v) for k, v in self.value.items() if not isinstance(v, XComArg))
+
def get_parse_time_mapped_ti_count(self) -> int | None:
if not self.value:
return 0
- literal_values = [len(v) for v in self.value.values() if isinstance(v, MAPPABLE_LITERAL_TYPES)]
+ literal_values = [len(v) for _, v in self.iter_parse_time_resolved_kwargs()]
if len(literal_values) != len(self.value):
return None # None-literal type encountered, so give up.
return functools.reduce(operator.mul, literal_values, 1)
@@ -184,12 +204,77 @@ class DictOfListsExpandInput(NamedTuple):
return {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()}
-ExpandInput = DictOfListsExpandInput
+class ListOfDictsExpandInput(NamedTuple):
+ """Storage type of a mapped operator's mapped kwargs.
+
+ This is created from ``expand_kwargs(xcom_arg)``.
+ """
+
+ value: XComArg
+
+ @staticmethod
+ def validate_xcom(value: Any) -> None:
+ if not isinstance(value, collections.abc.Collection):
+ raise UnmappableXComTypePushed(value)
+ if isinstance(value, (str, bytes, collections.abc.Mapping)):
+ raise UnmappableXComTypePushed(value)
+ for item in value:
+ if not isinstance(item, collections.abc.Mapping):
+ raise UnmappableXComTypePushed(value, item)
+ if not all(isinstance(k, str) for k in item):
+ raise UnmappableXComValuePushed(value, reason="dict keys must be str")
+
+ def get_unresolved_kwargs(self) -> dict[str, Any]:
+ """Get the kwargs dict that can be inferred without resolving.
+
+ Since the list-of-dicts case relies entirely on run-time XCom, there's
+ no kwargs structure available, so this just returns an empty dict.
+ """
+ return {}
+
+ def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
+ return ()
+
+ def get_parse_time_mapped_ti_count(self) -> int | None:
+ return None
+
+ def get_total_map_length(self, run_id: str, *, session: Session) -> int:
+ from airflow.models.taskmap import TaskMap
+ from airflow.models.xcom import XCom
+
+ task = self.value.operator
+ if task.is_mapped:
+ query = session.query(func.count(XCom.map_index)).filter(
+ XCom.dag_id == task.dag_id,
+ XCom.run_id == run_id,
+ XCom.task_id == task.task_id,
+ XCom.map_index >= 0,
+ )
+ else:
+ query = session.query(TaskMap.length).filter(
+ TaskMap.dag_id == task.dag_id,
+ TaskMap.run_id == run_id,
+ TaskMap.task_id == task.task_id,
+ TaskMap.map_index < 0,
+ )
+ value = query.scalar()
+ if value is None:
+ raise NotFullyPopulated({"expand_kwargs() argument"})
+ return value
+
+ def resolve(self, context: Context, session: Session) -> dict[str, Any]:
+ map_index = context["ti"].map_index
+ if map_index < 0:
+ raise RuntimeError("can't resolve task-mapping argument without expanding")
+ # Validation should be done when the upstream returns.
+ return self.value.resolve(context, session)[map_index]
+
EXPAND_INPUT_EMPTY = DictOfListsExpandInput({}) # Sentinel value.
_EXPAND_INPUT_TYPES = {
"dict-of-lists": DictOfListsExpandInput,
+ "list-of-dicts": ListOfDictsExpandInput,
}
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 9e75e9f4aa..a883ff2404 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -61,11 +61,12 @@ from airflow.models.abstractoperator import (
TaskStateChangeCallback,
)
from airflow.models.expandinput import (
- MAPPABLE_LITERAL_TYPES,
DictOfListsExpandInput,
ExpandInput,
+ ListOfDictsExpandInput,
Mappable,
NotFullyPopulated,
+ get_mappable_types,
)
from airflow.models.pool import Pool
from airflow.serialization.enums import DagAttributeTypes
@@ -86,19 +87,12 @@ if TYPE_CHECKING:
from airflow.models.dag import DAG
from airflow.models.operator import Operator
from airflow.models.taskinstance import TaskInstance
+ from airflow.models.xcom_arg import XComArg
from airflow.utils.task_group import TaskGroup
ValidationSource = Union[Literal["expand"], Literal["partial"]]
-# For isinstance() check.
-@cache
-def get_mappable_types() -> Tuple[type, ...]:
- from airflow.models.xcom_arg import XComArg
-
- return (XComArg,) + MAPPABLE_LITERAL_TYPES
-
-
def validate_mapping_kwargs(op: Type["BaseOperator"], func: ValidationSource, value: Dict[str, Any]) -> None:
# use a dict so order of args is same as code order
unknown_args = value.copy()
@@ -198,6 +192,13 @@ class OperatorPartial:
# to False to skip the checks on execution.
return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)
+ def expand_kwargs(self, kwargs: "XComArg", *, strict: bool = True) -> "MappedOperator":
+ from airflow.models.xcom_arg import XComArg
+
+ if not isinstance(kwargs, XComArg):
+ raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}")
+ return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)
+
def _expand(self, expand_input: ExpandInput, *, strict: bool) -> "MappedOperator":
from airflow.operators.empty import EmptyOperator
@@ -541,12 +542,10 @@ class MappedOperator(AbstractOperator):
operation on the list-of-dicts variant before execution time, an empty
dict will be returned for this case.
"""
- kwargs = self._get_specified_expand_input()
+ expand_input = self._get_specified_expand_input()
if resolve is not None:
- return kwargs.resolve(*resolve)
- if isinstance(kwargs, DictOfListsExpandInput):
- return kwargs.value
- return {}
+ return expand_input.resolve(*resolve)
+ return expand_input.get_unresolved_kwargs()
def _get_unmap_kwargs(self, mapped_kwargs: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
"""Get init kwargs to unmap the underlying operator class.
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 17c339620b..8110240d4c 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -1422,3 +1422,24 @@ def test_schedule_tis_map_index(dag_maker, session):
assert ti0.state == TaskInstanceState.SUCCESS
assert ti1.state == TaskInstanceState.SCHEDULED
assert ti2.state == TaskInstanceState.SUCCESS
+
+
+def test_mapped_expand_kwargs(dag_maker):
+ with dag_maker() as dag:
+
+ @task
+ def task_1():
+ return [{"arg1": "a", "arg2": "b"}, {"arg1": "y"}, {"arg2": "z"}]
+
+ MockOperator.partial(task_id="task_2").expand_kwargs(task_1())
+
+ dr: DagRun = dag_maker.create_dagrun()
+ assert len([ti for ti in dr.get_task_instances() if ti.task_id == "task_2"]) == 1
+
+ ti1 = dr.get_task_instance("task_1")
+ ti1.refresh_from_task(dag.get_task("task_1"))
+ ti1.run()
+
+ dr.task_instance_scheduling_decisions()
+ ti_states = {ti.map_index: ti.state for ti in dr.get_task_instances() if ti.task_id == "task_2"}
+ assert ti_states == {0: None, 1: None, 2: None}
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 92d6226097..09ab87524b 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -272,5 +272,110 @@ def test_mapped_render_template_fields_validating_operator(dag_maker, session):
assert isinstance(op, MyOperator)
assert op.value == "{{ ds }}", "Should not be templated!"
- assert op.arg1 == "{{ ds }}"
+ assert op.arg1 == "{{ ds }}", "Should not be templated!"
+ assert op.arg2 == "a"
+
+
+@pytest.mark.parametrize(
+ ["num_existing_tis", "expected"],
+ (
+ pytest.param(0, [(0, None), (1, None), (2, None)], id='only-unmapped-ti-exists'),
+ pytest.param(
+ 3,
+ [(0, 'success'), (1, 'success'), (2, 'success')],
+ id='all-tis-exist',
+ ),
+ pytest.param(
+ 5,
+ [
+ (0, 'success'),
+ (1, 'success'),
+ (2, 'success'),
+ (3, TaskInstanceState.REMOVED),
+ (4, TaskInstanceState.REMOVED),
+ ],
+ id="tis-to-be-removed",
+ ),
+ ),
+)
+def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis, expected):
+ literal = [{"arg1": "a"}, {"arg1": "b"}, {"arg1": "c"}]
+ with dag_maker(session=session):
+ task1 = BaseOperator(task_id="op1")
+ mapped = MockOperator.partial(task_id='task_2').expand_kwargs(XComArg(task1))
+
+ dr = dag_maker.create_dagrun()
+
+ session.add(
+ TaskMap(
+ dag_id=dr.dag_id,
+ task_id=task1.task_id,
+ run_id=dr.run_id,
+ map_index=-1,
+ length=len(literal),
+ keys=None,
+ )
+ )
+
+ if num_existing_tis:
+ # Remove the map_index=-1 TI when we're creating other TIs
+ session.query(TaskInstance).filter(
+ TaskInstance.dag_id == mapped.dag_id,
+ TaskInstance.task_id == mapped.task_id,
+ TaskInstance.run_id == dr.run_id,
+ ).delete()
+
+ for index in range(num_existing_tis):
+ # Give the existing TIs a state to make sure we don't change them
+ ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS)
+ session.add(ti)
+ session.flush()
+
+ mapped.expand_mapped_task(dr.run_id, session=session)
+
+ indices = (
+ session.query(TaskInstance.map_index, TaskInstance.state)
+ .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+ .order_by(TaskInstance.map_index)
+ .all()
+ )
+
+ assert indices == expected
+
+
+@pytest.mark.parametrize(
+ "map_index, expected",
+ [
+ pytest.param(0, "{{ ds }}", id="0"),
+ pytest.param(1, 2, id="1"),
+ ],
+)
+def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, map_index, expected):
+ with dag_maker(session=session):
+ task1 = BaseOperator(task_id="op1")
+ mapped = MockOperator.partial(task_id='a', arg2='{{ ti.task_id }}').expand_kwargs(XComArg(task1))
+
+ dr = dag_maker.create_dagrun()
+ ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
+
+ ti.xcom_push(key=XCOM_RETURN_KEY, value=[{"arg1": '{{ ds }}'}, {"arg1": 2}], session=session)
+
+ session.add(
+ TaskMap(
+ dag_id=dr.dag_id,
+ task_id=task1.task_id,
+ run_id=dr.run_id,
+ map_index=-1,
+ length=2,
+ keys=None,
+ )
+ )
+ session.flush()
+
+ ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session)
+ ti.refresh_from_task(mapped)
+ ti.map_index = map_index
+ op = mapped.render_template_fields(context=ti.get_template_context(session=session))
+ assert isinstance(op, MockOperator)
+ assert op.arg1 == expected
assert op.arg2 == "a"
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index e17f34bd78..05ac8daae5 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -43,6 +43,7 @@ from airflow.exceptions import (
AirflowSkipException,
UnmappableXComLengthPushed,
UnmappableXComTypePushed,
+ UnmappableXComValuePushed,
XComForMappingNotPushed,
)
from airflow.models import (
@@ -86,6 +87,7 @@ from tests.models import DEFAULT_DATE, TEST_DAGS_FOLDER
from tests.test_utils import db
from tests.test_utils.config import conf_vars
from tests.test_utils.db import clear_db_connections, clear_db_runs
+from tests.test_utils.mock_operators import MockOperator
@pytest.fixture
@@ -2500,6 +2502,94 @@ class TestTaskInstanceRecordTaskMapXComPush:
assert ti.state == TaskInstanceState.FAILED
assert str(ctx.value) == error_message
+ @pytest.mark.parametrize(
+ "return_value, exception_type, error_message",
+ [
+ (123, UnmappableXComTypePushed, "unmappable return type 'int'"),
+ ([123], UnmappableXComTypePushed, "unmappable return type 'list[int]'"),
+ ([{1: 3}], UnmappableXComValuePushed, "unmappable return value [{1: 3}] (dict keys must be str)"),
+ (None, XComForMappingNotPushed, "did not push XCom for task mapping"),
+ ],
+ )
+ def test_expand_kwargs_error_if_unmappable_type(
+ self,
+ dag_maker,
+ return_value,
+ exception_type,
+ error_message,
+ ):
+ """If an unmappable return value is used for expand_kwargs(), fail the task that pushed the XCom."""
+ with dag_maker(dag_id="test_expand_kwargs_error_if_unmappable_type") as dag:
+
+ @dag.task()
+ def push():
+ return return_value
+
+ MockOperator.partial(task_id="pull").expand_kwargs(push())
+
+ ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push")
+ with pytest.raises(exception_type) as ctx:
+ ti.run()
+
+ assert dag_maker.session.query(TaskMap).count() == 0
+ assert ti.state == TaskInstanceState.FAILED
+ assert str(ctx.value) == error_message
+
+ @pytest.mark.parametrize(
+ "downstream, error_message",
+ [
+ ("taskflow", "mapping already partial argument: arg2"),
+ ("classic", "unmappable or already specified argument: arg2"),
+ ],
+ ids=["taskflow", "classic"],
+ )
+ @pytest.mark.parametrize("strict", [True, False], ids=["strict", "override"])
+ def test_expand_kwargs_override_partial(self, dag_maker, session, downstream, error_message, strict):
+ class ClassicOperator(MockOperator):
+ def execute(self, context):
+ return (self.arg1, self.arg2)
+
+ with dag_maker(dag_id="test_expand_kwargs_override_partial", session=session) as dag:
+
+ @dag.task()
+ def push():
+ return [{"arg1": "a"}, {"arg1": "b", "arg2": "c"}]
+
+ push_task = push()
+
+ ClassicOperator.partial(task_id="classic", arg2="d").expand_kwargs(push_task, strict=strict)
+
+ @dag.task(task_id="taskflow")
+ def pull(arg1, arg2):
+ return (arg1, arg2)
+
+ pull.partial(arg2="d").expand_kwargs(push_task, strict=strict)
+
+ dr = dag_maker.create_dagrun()
+ next(ti for ti in dr.task_instances if ti.task_id == "push").run()
+
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ tis = {(ti.task_id, ti.map_index, ti.state): ti for ti in decision.schedulable_tis}
+ assert sorted(tis) == [
+ ("classic", 0, None),
+ ("classic", 1, None),
+ ("taskflow", 0, None),
+ ("taskflow", 1, None),
+ ]
+
+ ti = tis[((downstream, 0, None))]
+ ti.run()
+ ti.xcom_pull(task_ids=downstream, map_indexes=0, session=session) == ["a", "d"]
+
+ ti = tis[((downstream, 1, None))]
+ if strict:
+ with pytest.raises(TypeError) as ctx:
+ ti.run()
+ assert str(ctx.value) == error_message
+ else:
+ ti.run()
+ ti.xcom_pull(task_ids=downstream, map_indexes=1, session=session) == ["b", "c"]
+
def test_error_if_upstream_does_not_push(self, dag_maker):
"""Fail the upstream task if it fails to push the XCom used for task mapping."""
with dag_maker(dag_id="test_not_recorded_for_unused") as dag:
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 418cb4af89..5751ae137c 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1777,6 +1777,52 @@ def test_operator_expand_xcomarg_serde():
assert xcom_arg.operator is serialized_dag.task_dict['op1']
+@pytest.mark.parametrize("strict", [True, False])
+def test_operator_expand_kwargs_serde(strict):
+ from airflow.models.xcom_arg import XComArg
+
+ with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
+ task1 = BaseOperator(task_id="op1")
+ mapped = MockOperator.partial(task_id='task_2').expand_kwargs(XComArg(task1), strict=strict)
+
+ serialized = SerializedBaseOperator._serialize(mapped)
+ assert serialized == {
+ '_is_empty': False,
+ '_is_mapped': True,
+ '_task_module': 'tests.test_utils.mock_operators',
+ '_task_type': 'MockOperator',
+ 'downstream_task_ids': [],
+ 'expand_input': {
+ "type": "list-of-dicts",
+ "value": {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}},
+ },
+ 'partial_kwargs': {},
+ 'task_id': 'task_2',
+ 'template_fields': ['arg1', 'arg2'],
+ 'template_ext': [],
+ 'template_fields_renderers': {},
+ 'operator_extra_links': [],
+ 'ui_color': '#fff',
+ 'ui_fgcolor': '#000',
+ "_disallow_kwargs_override": strict,
+ '_expand_input_attr': 'expand_input',
+ }
+
+ op = SerializedBaseOperator.deserialize_operator(serialized)
+ assert op.deps is MappedOperator.deps_for(BaseOperator)
+ assert op._disallow_kwargs_override == strict
+
+ xcom_ref = op.expand_input.value
+ assert xcom_ref.task_id == 'op1'
+ assert xcom_ref.key == XCOM_RETURN_KEY
+
+ serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+ xcom_arg = serialized_dag.task_dict['task_2'].expand_input.value
+ assert isinstance(xcom_arg, XComArg)
+ assert xcom_arg.operator is serialized_dag.task_dict['op1']
+
+
def test_operator_expand_deserialized_unmap():
"""Unmap a deserialized mapped operator should be similar to deserializing an non-mapped operator."""
normal = BashOperator(task_id='a', bash_command=[1, 2], executor_config={"a": "b"})
@@ -1891,6 +1937,89 @@ def test_taskflow_expand_serde():
}
+@pytest.mark.parametrize("strict", [True, False])
+def test_taskflow_expand_kwargs_serde(strict):
+ from airflow.decorators import task
+ from airflow.models.xcom_arg import XComArg
+ from airflow.serialization.serialized_objects import _ExpandInputRef, _XComRef
+
+ with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
+ op1 = BaseOperator(task_id="op1")
+
+ @task(retry_delay=30)
+ def x(arg1, arg2, arg3):
+ print(arg1, arg2, arg3)
+
+ x.partial(arg1=[1, 2, {"a": "b"}]).expand_kwargs(XComArg(op1), strict=strict)
+
+ original = dag.get_task("x")
+
+ serialized = SerializedBaseOperator._serialize(original)
+ assert serialized == {
+ '_is_empty': False,
+ '_is_mapped': True,
+ '_task_module': 'airflow.decorators.python',
+ '_task_type': '_PythonDecoratedOperator',
+ 'downstream_task_ids': [],
+ 'partial_kwargs': {
+ 'op_args': [],
+ 'op_kwargs': {
+ '__type': 'dict',
+ '__var': {'arg1': [1, 2, {"__type": "dict", "__var": {'a': 'b'}}]},
+ },
+ 'retry_delay': {'__type': 'timedelta', '__var': 30.0},
+ },
+ 'op_kwargs_expand_input': {
+ "type": "list-of-dicts",
+ "value": {
+ "__type": "xcomref",
+ "__var": {'task_id': 'op1', 'key': 'return_value'},
+ },
+ },
+ 'operator_extra_links': [],
+ 'ui_color': '#ffefeb',
+ 'ui_fgcolor': '#000',
+ 'task_id': 'x',
+ 'template_ext': [],
+ 'template_fields': ['op_args', 'op_kwargs'],
+ 'template_fields_renderers': {"op_args": "py", "op_kwargs": "py"},
+ "_disallow_kwargs_override": strict,
+ '_expand_input_attr': 'op_kwargs_expand_input',
+ }
+
+ deserialized = SerializedBaseOperator.deserialize_operator(serialized)
+ assert isinstance(deserialized, MappedOperator)
+ assert deserialized.deps is MappedOperator.deps_for(BaseOperator)
+ assert deserialized._disallow_kwargs_override == strict
+ assert deserialized.upstream_task_ids == set()
+ assert deserialized.downstream_task_ids == set()
+
+ assert deserialized.op_kwargs_expand_input == _ExpandInputRef(
+ key="list-of-dicts",
+ value=_XComRef("op1", XCOM_RETURN_KEY),
+ )
+ assert deserialized.partial_kwargs == {
+ "op_args": [],
+ "op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
+ "retry_delay": timedelta(seconds=30),
+ }
+
+ # Ensure the serialized operator can also be correctly pickled, to ensure
+ # correct interaction between DAG pickling and serialization. This is done
+ # here so we don't need to duplicate tests between pickled and non-pickled
+ # DAGs everywhere else.
+ pickled = pickle.loads(pickle.dumps(deserialized))
+ assert pickled.op_kwargs_expand_input == _ExpandInputRef(
+ "list-of-dicts",
+ _XComRef("op1", XCOM_RETURN_KEY),
+ )
+ assert pickled.partial_kwargs == {
+ "op_args": [],
+ "op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
+ "retry_delay": timedelta(seconds=30),
+ }
+
+
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
@pytest.mark.parametrize(
"is_inherit",