Posted to commits@airflow.apache.org by ur...@apache.org on 2022/07/14 05:28:08 UTC

[airflow] branch main updated: Implement expand_kwargs() (#24989)

This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7d95bd9f41 Implement expand_kwargs() (#24989)
7d95bd9f41 is described below

commit 7d95bd9f416c9319f6b5c00058b0a1e3bd5bf805
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Thu Jul 14 13:28:01 2022 +0800

    Implement expand_kwargs() (#24989)
---
 airflow/decorators/base.py                    |  34 +++++--
 airflow/exceptions.py                         |  20 +++-
 airflow/models/baseoperator.py                |   3 +-
 airflow/models/expandinput.py                 |  95 ++++++++++++++++++-
 airflow/models/mappedoperator.py              |  27 +++---
 tests/models/test_dagrun.py                   |  21 +++++
 tests/models/test_mappedoperator.py           | 107 ++++++++++++++++++++-
 tests/models/test_taskinstance.py             |  90 ++++++++++++++++++
 tests/serialization/test_dag_serialization.py | 129 ++++++++++++++++++++++++++
 9 files changed, 496 insertions(+), 30 deletions(-)
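
For orientation before the diff: a minimal usage sketch of the new expand_kwargs() API, assembled from the tests added in this commit. The DAG id, date, and use of BashOperator are illustrative assumptions, not part of the commit; the essential contract is that the upstream XCom resolves to a list of dicts, and each dict becomes the keyword arguments of one mapped task instance.

    from datetime import datetime

    from airflow import DAG
    from airflow.decorators import task
    from airflow.operators.bash import BashOperator

    with DAG("expand_kwargs_example", start_date=datetime(2022, 7, 1)):

        @task
        def make_kwargs():
            # Each dict becomes one mapped task instance's kwargs.
            return [{"bash_command": "echo a"}, {"bash_command": "echo b"}]

        # expand_kwargs() accepts a single XComArg resolving to a list of dicts.
        BashOperator.partial(task_id="consumer").expand_kwargs(make_kwargs())

Classic operators and taskflow-decorated tasks get the same method; both entry points are added below (OperatorPartial.expand_kwargs and _TaskDecorator.expand_kwargs).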

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 9598571240..a36d7d5d43 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -53,7 +53,12 @@ from airflow.models.baseoperator import (
     parse_retries,
 )
 from airflow.models.dag import DAG, DagContext
-from airflow.models.expandinput import EXPAND_INPUT_EMPTY, DictOfListsExpandInput, ExpandInput
+from airflow.models.expandinput import (
+    EXPAND_INPUT_EMPTY,
+    DictOfListsExpandInput,
+    ExpandInput,
+    ListOfDictsExpandInput,
+)
 from airflow.models.mappedoperator import (
     MappedOperator,
     ValidationSource,
@@ -171,8 +176,17 @@ class DecoratedOperator(BaseOperator):
         op_args = op_args or []
         op_kwargs = op_kwargs or {}
 
-        # Check that arguments can be binded
-        inspect.signature(python_callable).bind(*op_args, **op_kwargs)
+        # Check that arguments can be bound. There's a slight difference when
+        # we do validation for task-mapping: since there's no guarantee we can
+        # receive enough arguments at parse time, we use bind_partial to simply
+        # check that the arguments we know about are valid. Whether these are
+        # enough can only be known at execution time, when unmapping happens,
+        # and this is called without the _airflow_mapped_validation_only flag.
+        if kwargs.get("_airflow_mapped_validation_only"):
+            inspect.signature(python_callable).bind_partial(*op_args, **op_kwargs)
+        else:
+            inspect.signature(python_callable).bind(*op_args, **op_kwargs)
+
         self.multiple_outputs = multiple_outputs
         self.op_args = op_args
         self.op_kwargs = op_kwargs
@@ -323,6 +337,13 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
         # to False to skip the checks on execution.
         return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)
 
+    def expand_kwargs(self, kwargs: "XComArg", *, strict: bool = True) -> XComArg:
+        from airflow.models.xcom_arg import XComArg
+
+        if not isinstance(kwargs, XComArg):
+            raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}")
+        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)
+
     def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
         ensure_xcomarg_return_value(expand_input.value)
 
@@ -442,10 +463,11 @@ class DecoratedMappedOperator(MappedOperator):
                 mapped_kwargs["op_kwargs"],
                 fail_reason="mapping already partial",
             )
+
+        static_kwargs = {k for k, _ in self.op_kwargs_expand_input.iter_parse_time_resolved_kwargs()}
         self._combined_op_kwargs = {**self.partial_kwargs["op_kwargs"], **mapped_kwargs["op_kwargs"]}
-        self._already_resolved_op_kwargs = {
-            k for k, v in self.op_kwargs_expand_input.value.items() if isinstance(v, XComArg)
-        }
+        self._already_resolved_op_kwargs = {k for k in mapped_kwargs["op_kwargs"] if k not in static_kwargs}
+
         kwargs = {
             "multiple_outputs": self.multiple_outputs,
             "python_callable": self.python_callable,
diff --git a/airflow/exceptions.py b/airflow/exceptions.py
index f1a8c1cb66..7a91100f11 100644
--- a/airflow/exceptions.py
+++ b/airflow/exceptions.py
@@ -113,12 +113,26 @@ class XComForMappingNotPushed(AirflowException):
 class UnmappableXComTypePushed(AirflowException):
     """Raise when an unmappable type is pushed as a mapped downstream's dependency."""
 
-    def __init__(self, value: Any) -> None:
-        super().__init__(value)
+    def __init__(self, value: Any, *values: Any) -> None:
+        super().__init__(value, *values)
+
+    def __str__(self) -> str:
+        typename = type(self.args[0]).__qualname__
+        for arg in self.args[1:]:
+            typename = f"{typename}[{type(arg).__qualname__}]"
+        return f"unmappable return type {typename!r}"
+
+
+class UnmappableXComValuePushed(AirflowException):
+    """Raise when an invalid value is pushed as a mapped downstream's dependency."""
+
+    def __init__(self, value: Any, reason: str) -> None:
+        super().__init__(value, reason)
         self.value = value
+        self.reason = reason
 
     def __str__(self) -> str:
-        return f"unmappable return type {type(self.value).__qualname__!r}"
+        return f"unmappable return value {self.value!r} ({self.reason})"
 
 
 class UnmappableXComLengthPushed(AirflowException):
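
Given the __str__ implementations above, the messages nest the offending item's type inside the container's; a short sketch whose printed output mirrors the expectations in the tests added further down.

    from airflow.exceptions import UnmappableXComTypePushed, UnmappableXComValuePushed

    print(UnmappableXComTypePushed(123))         # unmappable return type 'int'
    print(UnmappableXComTypePushed([123], 123))  # unmappable return type 'list[int]'
    print(UnmappableXComValuePushed([{1: 3}], "dict keys must be str"))
    # unmappable return value [{1: 3}] (dict keys must be str)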
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 05ef24a124..2795a0f538 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -756,6 +756,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
 
         super().__init__()
 
+        kwargs.pop("_airflow_mapped_validation_only", None)
         if kwargs:
             if not conf.getboolean('operators', 'ALLOW_ILLEGAL_ARGUMENTS'):
                 raise AirflowException(
@@ -1509,7 +1510,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     def validate_mapped_arguments(cls, **kwargs: Any) -> None:
         """Validate arguments when this operator is being mapped."""
         if cls.mapped_arguments_validated_by_init:
-            cls(**kwargs, _airflow_from_mapped=True)
+            cls(**kwargs, _airflow_from_mapped=True, _airflow_mapped_validation_only=True)
 
     def unmap(self, ctx: Union[None, Dict[str, Any], Tuple[Context, Session]]) -> "BaseOperator":
         """:meta private:"""
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index 86623b41a1..b5b922f9df 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -22,26 +22,36 @@ import collections
 import collections.abc
 import functools
 import operator
-from typing import TYPE_CHECKING, Any, NamedTuple, Sequence, Union
+from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Sequence, Sized, Union
 
 from sqlalchemy import func
 from sqlalchemy.orm import Session
 
-from airflow.exceptions import UnmappableXComTypePushed
+from airflow.compat.functools import cache
+from airflow.exceptions import UnmappableXComTypePushed, UnmappableXComValuePushed
 from airflow.utils.context import Context
 
 if TYPE_CHECKING:
     from airflow.models.xcom_arg import XComArg
 
+ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]
+
 # BaseOperator.expand() can be called on an XComArg, sequence, or dict (not any
 # mapping since we need the value to be ordered).
 Mappable = Union["XComArg", Sequence, dict]
 
-MAPPABLE_LITERAL_TYPES = (dict, list)
+
+# For isinstance() check.
+@cache
+def get_mappable_types() -> tuple[type, ...]:
+    from airflow.models.xcom_arg import XComArg
+
+    return (XComArg, list, tuple, dict)
 
 
 class NotFullyPopulated(RuntimeError):
     """Raise when ``get_map_lengths`` cannot populate all mapping metadata.
+
     This is generally because not all upstream tasks have finished when
     the function is called.
     """
@@ -67,10 +77,20 @@ class DictOfListsExpandInput(NamedTuple):
         if not isinstance(value, collections.abc.Collection) or isinstance(value, (bytes, str)):
             raise UnmappableXComTypePushed(value)
 
+    def get_unresolved_kwargs(self) -> dict[str, Any]:
+        """Get the kwargs dict that can be inferred without resolving."""
+        return self.value
+
+    def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
+        """Generate kwargs with values available on parse-time."""
+        from airflow.models.xcom_arg import XComArg
+
+        return ((k, v) for k, v in self.value.items() if not isinstance(v, XComArg))
+
     def get_parse_time_mapped_ti_count(self) -> int | None:
         if not self.value:
             return 0
-        literal_values = [len(v) for v in self.value.values() if isinstance(v, MAPPABLE_LITERAL_TYPES)]
+        literal_values = [len(v) for _, v in self.iter_parse_time_resolved_kwargs()]
         if len(literal_values) != len(self.value):
             return None  # None-literal type encountered, so give up.
         return functools.reduce(operator.mul, literal_values, 1)
@@ -184,12 +204,77 @@ class DictOfListsExpandInput(NamedTuple):
         return {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()}
 
 
-ExpandInput = DictOfListsExpandInput
+class ListOfDictsExpandInput(NamedTuple):
+    """Storage type of a mapped operator's mapped kwargs.
+
+    This is created from ``expand_kwargs(xcom_arg)``.
+    """
+
+    value: XComArg
+
+    @staticmethod
+    def validate_xcom(value: Any) -> None:
+        if not isinstance(value, collections.abc.Collection):
+            raise UnmappableXComTypePushed(value)
+        if isinstance(value, (str, bytes, collections.abc.Mapping)):
+            raise UnmappableXComTypePushed(value)
+        for item in value:
+            if not isinstance(item, collections.abc.Mapping):
+                raise UnmappableXComTypePushed(value, item)
+            if not all(isinstance(k, str) for k in item):
+                raise UnmappableXComValuePushed(value, reason="dict keys must be str")
+
+    def get_unresolved_kwargs(self) -> dict[str, Any]:
+        """Get the kwargs dict that can be inferred without resolving.
+
+        Since the list-of-dicts case relies entirely on run-time XCom, there's
+        no kwargs structure available, so this just returns an empty dict.
+        """
+        return {}
+
+    def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
+        return ()
+
+    def get_parse_time_mapped_ti_count(self) -> int | None:
+        return None
+
+    def get_total_map_length(self, run_id: str, *, session: Session) -> int:
+        from airflow.models.taskmap import TaskMap
+        from airflow.models.xcom import XCom
+
+        task = self.value.operator
+        if task.is_mapped:
+            query = session.query(func.count(XCom.map_index)).filter(
+                XCom.dag_id == task.dag_id,
+                XCom.run_id == run_id,
+                XCom.task_id == task.task_id,
+                XCom.map_index >= 0,
+            )
+        else:
+            query = session.query(TaskMap.length).filter(
+                TaskMap.dag_id == task.dag_id,
+                TaskMap.run_id == run_id,
+                TaskMap.task_id == task.task_id,
+                TaskMap.map_index < 0,
+            )
+        value = query.scalar()
+        if value is None:
+            raise NotFullyPopulated({"expand_kwargs() argument"})
+        return value
+
+    def resolve(self, context: Context, session: Session) -> dict[str, Any]:
+        map_index = context["ti"].map_index
+        if map_index < 0:
+            raise RuntimeError("can't resolve task-mapping argument without expanding")
+        # Validation should be done when the upstream returns.
+        return self.value.resolve(context, session)[map_index]
+
 
 EXPAND_INPUT_EMPTY = DictOfListsExpandInput({})  # Sentinel value.
 
 _EXPAND_INPUT_TYPES = {
     "dict-of-lists": DictOfListsExpandInput,
+    "list-of-dicts": ListOfDictsExpandInput,
 }
 
 
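A sketch of how the validate_xcom rules above classify pushed values; the inputs are illustrative, and the expected messages follow from the exceptions.py hunk earlier in this commit.

    from airflow.models.expandinput import ListOfDictsExpandInput

    ListOfDictsExpandInput.validate_xcom([{"arg1": "a"}, {"arg2": "b"}])  # OK

    for bad in (123, "abc", {"a": 1}, [123], [{1: 3}]):
        try:
            ListOfDictsExpandInput.validate_xcom(bad)
        except Exception as e:
            print(type(e).__name__, e)

    # UnmappableXComTypePushed unmappable return type 'int'
    # UnmappableXComTypePushed unmappable return type 'str'   (strings are rejected explicitly)
    # UnmappableXComTypePushed unmappable return type 'dict'  (a top-level mapping is rejected)
    # UnmappableXComTypePushed unmappable return type 'list[int]'
    # UnmappableXComValuePushed unmappable return value [{1: 3}] (dict keys must be str)
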
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index 9e75e9f4aa..a883ff2404 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -61,11 +61,12 @@ from airflow.models.abstractoperator import (
     TaskStateChangeCallback,
 )
 from airflow.models.expandinput import (
-    MAPPABLE_LITERAL_TYPES,
     DictOfListsExpandInput,
     ExpandInput,
+    ListOfDictsExpandInput,
     Mappable,
     NotFullyPopulated,
+    get_mappable_types,
 )
 from airflow.models.pool import Pool
 from airflow.serialization.enums import DagAttributeTypes
@@ -86,19 +87,12 @@ if TYPE_CHECKING:
     from airflow.models.dag import DAG
     from airflow.models.operator import Operator
     from airflow.models.taskinstance import TaskInstance
+    from airflow.models.xcom_arg import XComArg
     from airflow.utils.task_group import TaskGroup
 
 ValidationSource = Union[Literal["expand"], Literal["partial"]]
 
 
-# For isinstance() check.
-@cache
-def get_mappable_types() -> Tuple[type, ...]:
-    from airflow.models.xcom_arg import XComArg
-
-    return (XComArg,) + MAPPABLE_LITERAL_TYPES
-
-
 def validate_mapping_kwargs(op: Type["BaseOperator"], func: ValidationSource, value: Dict[str, Any]) -> None:
     # use a dict so order of args is same as code order
     unknown_args = value.copy()
@@ -198,6 +192,13 @@ class OperatorPartial:
         # to False to skip the checks on execution.
         return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)
 
+    def expand_kwargs(self, kwargs: "XComArg", *, strict: bool = True) -> "MappedOperator":
+        from airflow.models.xcom_arg import XComArg
+
+        if not isinstance(kwargs, XComArg):
+            raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}")
+        return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)
+
     def _expand(self, expand_input: ExpandInput, *, strict: bool) -> "MappedOperator":
         from airflow.operators.empty import EmptyOperator
 
@@ -541,12 +542,10 @@ class MappedOperator(AbstractOperator):
         operation on the list-of-dicts variant before execution time, an empty
         dict will be returned for this case.
         """
-        kwargs = self._get_specified_expand_input()
+        expand_input = self._get_specified_expand_input()
         if resolve is not None:
-            return kwargs.resolve(*resolve)
-        if isinstance(kwargs, DictOfListsExpandInput):
-            return kwargs.value
-        return {}
+            return expand_input.resolve(*resolve)
+        return expand_input.get_unresolved_kwargs()
 
     def _get_unmap_kwargs(self, mapped_kwargs: Dict[str, Any], *, strict: bool) -> Dict[str, Any]:
         """Get init kwargs to unmap the underlying operator class.
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index 17c339620b..8110240d4c 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -1422,3 +1422,24 @@ def test_schedule_tis_map_index(dag_maker, session):
     assert ti0.state == TaskInstanceState.SUCCESS
     assert ti1.state == TaskInstanceState.SCHEDULED
     assert ti2.state == TaskInstanceState.SUCCESS
+
+
+def test_mapped_expand_kwargs(dag_maker):
+    with dag_maker() as dag:
+
+        @task
+        def task_1():
+            return [{"arg1": "a", "arg2": "b"}, {"arg1": "y"}, {"arg2": "z"}]
+
+        MockOperator.partial(task_id="task_2").expand_kwargs(task_1())
+
+    dr: DagRun = dag_maker.create_dagrun()
+    assert len([ti for ti in dr.get_task_instances() if ti.task_id == "task_2"]) == 1
+
+    ti1 = dr.get_task_instance("task_1")
+    ti1.refresh_from_task(dag.get_task("task_1"))
+    ti1.run()
+
+    dr.task_instance_scheduling_decisions()
+    ti_states = {ti.map_index: ti.state for ti in dr.get_task_instances() if ti.task_id == "task_2"}
+    assert ti_states == {0: None, 1: None, 2: None}
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 92d6226097..09ab87524b 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -272,5 +272,110 @@ def test_mapped_render_template_fields_validating_operator(dag_maker, session):
     assert isinstance(op, MyOperator)
 
     assert op.value == "{{ ds }}", "Should not be templated!"
-    assert op.arg1 == "{{ ds }}"
+    assert op.arg1 == "{{ ds }}", "Should not be templated!"
+    assert op.arg2 == "a"
+
+
+@pytest.mark.parametrize(
+    ["num_existing_tis", "expected"],
+    (
+        pytest.param(0, [(0, None), (1, None), (2, None)], id='only-unmapped-ti-exists'),
+        pytest.param(
+            3,
+            [(0, 'success'), (1, 'success'), (2, 'success')],
+            id='all-tis-exist',
+        ),
+        pytest.param(
+            5,
+            [
+                (0, 'success'),
+                (1, 'success'),
+                (2, 'success'),
+                (3, TaskInstanceState.REMOVED),
+                (4, TaskInstanceState.REMOVED),
+            ],
+            id="tis-to-be-removed",
+        ),
+    ),
+)
+def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis, expected):
+    literal = [{"arg1": "a"}, {"arg1": "b"}, {"arg1": "c"}]
+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
+        mapped = MockOperator.partial(task_id='task_2').expand_kwargs(XComArg(task1))
+
+    dr = dag_maker.create_dagrun()
+
+    session.add(
+        TaskMap(
+            dag_id=dr.dag_id,
+            task_id=task1.task_id,
+            run_id=dr.run_id,
+            map_index=-1,
+            length=len(literal),
+            keys=None,
+        )
+    )
+
+    if num_existing_tis:
+        # Remove the map_index=-1 TI when we're creating other TIs
+        session.query(TaskInstance).filter(
+            TaskInstance.dag_id == mapped.dag_id,
+            TaskInstance.task_id == mapped.task_id,
+            TaskInstance.run_id == dr.run_id,
+        ).delete()
+
+    for index in range(num_existing_tis):
+        # Give the existing TIs a state to make sure we don't change them
+        ti = TaskInstance(mapped, run_id=dr.run_id, map_index=index, state=TaskInstanceState.SUCCESS)
+        session.add(ti)
+    session.flush()
+
+    mapped.expand_mapped_task(dr.run_id, session=session)
+
+    indices = (
+        session.query(TaskInstance.map_index, TaskInstance.state)
+        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+        .order_by(TaskInstance.map_index)
+        .all()
+    )
+
+    assert indices == expected
+
+
+@pytest.mark.parametrize(
+    "map_index, expected",
+    [
+        pytest.param(0, "{{ ds }}", id="0"),
+        pytest.param(1, 2, id="1"),
+    ],
+)
+def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, map_index, expected):
+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
+        mapped = MockOperator.partial(task_id='a', arg2='{{ ti.task_id }}').expand_kwargs(XComArg(task1))
+
+    dr = dag_maker.create_dagrun()
+    ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
+
+    ti.xcom_push(key=XCOM_RETURN_KEY, value=[{"arg1": '{{ ds }}'}, {"arg1": 2}], session=session)
+
+    session.add(
+        TaskMap(
+            dag_id=dr.dag_id,
+            task_id=task1.task_id,
+            run_id=dr.run_id,
+            map_index=-1,
+            length=2,
+            keys=None,
+        )
+    )
+    session.flush()
+
+    ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session)
+    ti.refresh_from_task(mapped)
+    ti.map_index = map_index
+    op = mapped.render_template_fields(context=ti.get_template_context(session=session))
+    assert isinstance(op, MockOperator)
+    assert op.arg1 == expected
     assert op.arg2 == "a"
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index e17f34bd78..05ac8daae5 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -43,6 +43,7 @@ from airflow.exceptions import (
     AirflowSkipException,
     UnmappableXComLengthPushed,
     UnmappableXComTypePushed,
+    UnmappableXComValuePushed,
     XComForMappingNotPushed,
 )
 from airflow.models import (
@@ -86,6 +87,7 @@ from tests.models import DEFAULT_DATE, TEST_DAGS_FOLDER
 from tests.test_utils import db
 from tests.test_utils.config import conf_vars
 from tests.test_utils.db import clear_db_connections, clear_db_runs
+from tests.test_utils.mock_operators import MockOperator
 
 
 @pytest.fixture
@@ -2500,6 +2502,94 @@ class TestTaskInstanceRecordTaskMapXComPush:
         assert ti.state == TaskInstanceState.FAILED
         assert str(ctx.value) == error_message
 
+    @pytest.mark.parametrize(
+        "return_value, exception_type, error_message",
+        [
+            (123, UnmappableXComTypePushed, "unmappable return type 'int'"),
+            ([123], UnmappableXComTypePushed, "unmappable return type 'list[int]'"),
+            ([{1: 3}], UnmappableXComValuePushed, "unmappable return value [{1: 3}] (dict keys must be str)"),
+            (None, XComForMappingNotPushed, "did not push XCom for task mapping"),
+        ],
+    )
+    def test_expand_kwargs_error_if_unmappable_type(
+        self,
+        dag_maker,
+        return_value,
+        exception_type,
+        error_message,
+    ):
+        """If an unmappable return value is used for expand_kwargs(), fail the task that pushed the XCom."""
+        with dag_maker(dag_id="test_expand_kwargs_error_if_unmappable_type") as dag:
+
+            @dag.task()
+            def push():
+                return return_value
+
+            MockOperator.partial(task_id="pull").expand_kwargs(push())
+
+        ti = next(ti for ti in dag_maker.create_dagrun().task_instances if ti.task_id == "push")
+        with pytest.raises(exception_type) as ctx:
+            ti.run()
+
+        assert dag_maker.session.query(TaskMap).count() == 0
+        assert ti.state == TaskInstanceState.FAILED
+        assert str(ctx.value) == error_message
+
+    @pytest.mark.parametrize(
+        "downstream, error_message",
+        [
+            ("taskflow", "mapping already partial argument: arg2"),
+            ("classic", "unmappable or already specified argument: arg2"),
+        ],
+        ids=["taskflow", "classic"],
+    )
+    @pytest.mark.parametrize("strict", [True, False], ids=["strict", "override"])
+    def test_expand_kwargs_override_partial(self, dag_maker, session, downstream, error_message, strict):
+        class ClassicOperator(MockOperator):
+            def execute(self, context):
+                return (self.arg1, self.arg2)
+
+        with dag_maker(dag_id="test_expand_kwargs_override_partial", session=session) as dag:
+
+            @dag.task()
+            def push():
+                return [{"arg1": "a"}, {"arg1": "b", "arg2": "c"}]
+
+            push_task = push()
+
+            ClassicOperator.partial(task_id="classic", arg2="d").expand_kwargs(push_task, strict=strict)
+
+            @dag.task(task_id="taskflow")
+            def pull(arg1, arg2):
+                return (arg1, arg2)
+
+            pull.partial(arg2="d").expand_kwargs(push_task, strict=strict)
+
+        dr = dag_maker.create_dagrun()
+        next(ti for ti in dr.task_instances if ti.task_id == "push").run()
+
+        decision = dr.task_instance_scheduling_decisions(session=session)
+        tis = {(ti.task_id, ti.map_index, ti.state): ti for ti in decision.schedulable_tis}
+        assert sorted(tis) == [
+            ("classic", 0, None),
+            ("classic", 1, None),
+            ("taskflow", 0, None),
+            ("taskflow", 1, None),
+        ]
+
+        ti = tis[(downstream, 0, None)]
+        ti.run()
+        assert ti.xcom_pull(task_ids=downstream, map_indexes=0, session=session) == ["a", "d"]
+
+        ti = tis[(downstream, 1, None)]
+        if strict:
+            with pytest.raises(TypeError) as ctx:
+                ti.run()
+            assert str(ctx.value) == error_message
+        else:
+            ti.run()
+            assert ti.xcom_pull(task_ids=downstream, map_indexes=1, session=session) == ["b", "c"]
+
     def test_error_if_upstream_does_not_push(self, dag_maker):
         """Fail the upstream task if it fails to push the XCom used for task mapping."""
         with dag_maker(dag_id="test_not_recorded_for_unused") as dag:
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 418cb4af89..5751ae137c 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1777,6 +1777,52 @@ def test_operator_expand_xcomarg_serde():
     assert xcom_arg.operator is serialized_dag.task_dict['op1']
 
 
+@pytest.mark.parametrize("strict", [True, False])
+def test_operator_expand_kwargs_serde(strict):
+    from airflow.models.xcom_arg import XComArg
+
+    with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
+        task1 = BaseOperator(task_id="op1")
+        mapped = MockOperator.partial(task_id='task_2').expand_kwargs(XComArg(task1), strict=strict)
+
+    serialized = SerializedBaseOperator._serialize(mapped)
+    assert serialized == {
+        '_is_empty': False,
+        '_is_mapped': True,
+        '_task_module': 'tests.test_utils.mock_operators',
+        '_task_type': 'MockOperator',
+        'downstream_task_ids': [],
+        'expand_input': {
+            "type": "list-of-dicts",
+            "value": {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}},
+        },
+        'partial_kwargs': {},
+        'task_id': 'task_2',
+        'template_fields': ['arg1', 'arg2'],
+        'template_ext': [],
+        'template_fields_renderers': {},
+        'operator_extra_links': [],
+        'ui_color': '#fff',
+        'ui_fgcolor': '#000',
+        "_disallow_kwargs_override": strict,
+        '_expand_input_attr': 'expand_input',
+    }
+
+    op = SerializedBaseOperator.deserialize_operator(serialized)
+    assert op.deps is MappedOperator.deps_for(BaseOperator)
+    assert op._disallow_kwargs_override == strict
+
+    xcom_ref = op.expand_input.value
+    assert xcom_ref.task_id == 'op1'
+    assert xcom_ref.key == XCOM_RETURN_KEY
+
+    serialized_dag: DAG = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))
+
+    xcom_arg = serialized_dag.task_dict['task_2'].expand_input.value
+    assert isinstance(xcom_arg, XComArg)
+    assert xcom_arg.operator is serialized_dag.task_dict['op1']
+
+
 def test_operator_expand_deserialized_unmap():
     """Unmap a deserialized mapped operator should be similar to deserializing an non-mapped operator."""
     normal = BashOperator(task_id='a', bash_command=[1, 2], executor_config={"a": "b"})
@@ -1891,6 +1937,89 @@ def test_taskflow_expand_serde():
     }
 
 
+@pytest.mark.parametrize("strict", [True, False])
+def test_taskflow_expand_kwargs_serde(strict):
+    from airflow.decorators import task
+    from airflow.models.xcom_arg import XComArg
+    from airflow.serialization.serialized_objects import _ExpandInputRef, _XComRef
+
+    with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
+        op1 = BaseOperator(task_id="op1")
+
+        @task(retry_delay=30)
+        def x(arg1, arg2, arg3):
+            print(arg1, arg2, arg3)
+
+        x.partial(arg1=[1, 2, {"a": "b"}]).expand_kwargs(XComArg(op1), strict=strict)
+
+    original = dag.get_task("x")
+
+    serialized = SerializedBaseOperator._serialize(original)
+    assert serialized == {
+        '_is_empty': False,
+        '_is_mapped': True,
+        '_task_module': 'airflow.decorators.python',
+        '_task_type': '_PythonDecoratedOperator',
+        'downstream_task_ids': [],
+        'partial_kwargs': {
+            'op_args': [],
+            'op_kwargs': {
+                '__type': 'dict',
+                '__var': {'arg1': [1, 2, {"__type": "dict", "__var": {'a': 'b'}}]},
+            },
+            'retry_delay': {'__type': 'timedelta', '__var': 30.0},
+        },
+        'op_kwargs_expand_input': {
+            "type": "list-of-dicts",
+            "value": {
+                "__type": "xcomref",
+                "__var": {'task_id': 'op1', 'key': 'return_value'},
+            },
+        },
+        'operator_extra_links': [],
+        'ui_color': '#ffefeb',
+        'ui_fgcolor': '#000',
+        'task_id': 'x',
+        'template_ext': [],
+        'template_fields': ['op_args', 'op_kwargs'],
+        'template_fields_renderers': {"op_args": "py", "op_kwargs": "py"},
+        "_disallow_kwargs_override": strict,
+        '_expand_input_attr': 'op_kwargs_expand_input',
+    }
+
+    deserialized = SerializedBaseOperator.deserialize_operator(serialized)
+    assert isinstance(deserialized, MappedOperator)
+    assert deserialized.deps is MappedOperator.deps_for(BaseOperator)
+    assert deserialized._disallow_kwargs_override == strict
+    assert deserialized.upstream_task_ids == set()
+    assert deserialized.downstream_task_ids == set()
+
+    assert deserialized.op_kwargs_expand_input == _ExpandInputRef(
+        key="list-of-dicts",
+        value=_XComRef("op1", XCOM_RETURN_KEY),
+    )
+    assert deserialized.partial_kwargs == {
+        "op_args": [],
+        "op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
+        "retry_delay": timedelta(seconds=30),
+    }
+
+    # Ensure the serialized operator can also be correctly pickled, to ensure
+    # correct interaction between DAG pickling and serialization. This is done
+    # here so we don't need to duplicate tests between pickled and non-pickled
+    # DAGs everywhere else.
+    pickled = pickle.loads(pickle.dumps(deserialized))
+    assert pickled.op_kwargs_expand_input == _ExpandInputRef(
+        "list-of-dicts",
+        _XComRef("op1", XCOM_RETURN_KEY),
+    )
+    assert pickled.partial_kwargs == {
+        "op_args": [],
+        "op_kwargs": {"arg1": [1, 2, {"a": "b"}]},
+        "retry_delay": timedelta(seconds=30),
+    }
+
+
 @pytest.mark.filterwarnings("ignore::DeprecationWarning")
 @pytest.mark.parametrize(
     "is_inherit",