Posted to commits@airflow.apache.org by ur...@apache.org on 2022/02/10 07:08:46 UTC

[airflow] branch main updated: Rewrite decorated task mapping (#21328)

This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new fded2ca  Rewrite decorated task mapping (#21328)
fded2ca is described below

commit fded2ca0b9c995737b401896b89e5c9fd7f24c91
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Thu Feb 10 15:07:50 2022 +0800

    Rewrite decorated task mapping (#21328)
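
    A minimal sketch of the mapping API this commit rewrites for
    @task-decorated functions, pieced together from the tests in the
    diff below (the DAG id and the "prefix" argument are illustrative,
    not part of the change):

        from airflow import DAG
        from airflow.decorators import task
        from airflow.utils.dates import days_ago

        with DAG(dag_id="example_mapped_taskflow", start_date=days_ago(2)) as dag:

            @task
            def make_list():
                return [1, 2, {"a": "b"}]

            @task
            def consumer(value, prefix):
                print(prefix, repr(value))

            # partial() pins arguments shared by every mapped instance;
            # map() expands the task over a literal list or an XComArg.
            consumer.partial(prefix="got:").map(value=make_list())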
---
 airflow/decorators/base.py                    | 92 ++++++++++++++++++++++-----
 airflow/models/baseoperator.py                | 55 +++-------------
 airflow/models/taskinstance.py                |  2 +-
 airflow/serialization/serialized_objects.py   | 27 ++++++--
 tests/dags/test_mapped_taskflow.py            | 31 +++++++++
 tests/decorators/test_python.py               | 69 ++++++++++++++++----
 tests/jobs/test_backfill_job.py               | 16 +++--
 tests/serialization/test_dag_serialization.py | 53 +++++++++++++++
 8 files changed, 259 insertions(+), 86 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 9cf423f..53a12c6 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -280,30 +280,88 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
             names = ", ".join(repr(n) for n in unknown_args)
             raise TypeError(f'{funcname} got unexpected keyword arguments {names}')
 
-    def map(
-        self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
-    ) -> XComArg:
+    def map(self, *args, **kwargs) -> XComArg:
         self._validate_arg_names("map", kwargs)
-        dag = dag or DagContext.get_current_dag()
-        task_group = task_group or TaskGroupContext.get_current_task_group(dag)
-        task_id = get_unique_task_id(self.kwargs['task_id'], dag, task_group)
 
-        operator = MappedOperator.from_decorator(
-            decorator=self,
+        partial_kwargs = self.kwargs.copy()
+        dag = partial_kwargs.pop("dag", DagContext.get_current_dag())
+        task_group = partial_kwargs.pop("task_group", TaskGroupContext.get_current_task_group(dag))
+        task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
+
+        # Unfortunately attrs's type hinting support does not work well with
+        # subclassing; it complains that arguments forwarded to the superclass
+        # are "unexpected" (they are fine at runtime).
+        operator = cast(Any, DecoratedMappedOperator)(
+            operator_class=self.operator_class,
+            partial_kwargs=partial_kwargs,
+            mapped_kwargs={},
+            task_id=task_id,
             dag=dag,
             task_group=task_group,
-            task_id=task_id,
-            mapped_kwargs=kwargs,
+            deps=MappedOperator._deps(self.operator_class.deps),
+            multiple_outputs=self.multiple_outputs,
+            python_callable=self.function,
         )
+
+        operator.mapped_kwargs["op_args"] = list(args)
+        operator.mapped_kwargs["op_kwargs"] = kwargs
+
+        for arg in itertools.chain(args, kwargs.values()):
+            XComArg.apply_upstream_relationship(operator, arg)
         return XComArg(operator=operator)
 
-    def partial(
-        self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
-    ) -> "_TaskDecorator[Function, OperatorSubclass]":
-        self._validate_arg_names("partial", kwargs, {'task_id'})
-        partial_kwargs = self.kwargs.copy()
-        partial_kwargs.update(kwargs)
-        return attr.evolve(self, kwargs=partial_kwargs)
+    def partial(self, *args, **kwargs) -> "_TaskDecorator[Function, OperatorSubclass]":
+        self._validate_arg_names("partial", kwargs)
+
+        op_args = list(self.kwargs.get("op_args", ()))  # Copy: avoid mutating state shared across partial() calls.
+        op_args.extend(args)
+
+        op_kwargs = self.kwargs.get("op_kwargs", {})
+        op_kwargs = _merge_kwargs(op_kwargs, kwargs, fail_reason="duplicate partial")
+
+        return attr.evolve(self, kwargs={**self.kwargs, "op_args": op_args, "op_kwargs": op_kwargs})
+
+
+def _merge_kwargs(
+    kwargs1: Dict[str, XComArg],
+    kwargs2: Dict[str, XComArg],
+    *,
+    fail_reason: str,
+) -> Dict[str, XComArg]:
+    duplicated_keys = set(kwargs1).intersection(kwargs2)
+    if len(duplicated_keys) == 1:
+        raise TypeError(f"{fail_reason} argument: {duplicated_keys.pop()}")
+    elif duplicated_keys:
+        duplicated_keys_display = ", ".join(sorted(duplicated_keys))
+        raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}")
+    return {**kwargs1, **kwargs2}
+
+
+@attr.define(kw_only=True)
+class DecoratedMappedOperator(MappedOperator):
+    """MappedOperator implementation for @task-decorated task function."""
+
+    multiple_outputs: bool
+    python_callable: Callable
+
+    def create_unmapped_operator(self, dag: "DAG") -> BaseOperator:
+        assert not isinstance(self.operator_class, str)
+        op_args = self.partial_kwargs.pop("op_args", []) + self.mapped_kwargs.pop("op_args", [])
+        op_kwargs = _merge_kwargs(
+            self.partial_kwargs.pop("op_kwargs", {}),
+            self.mapped_kwargs.pop("op_kwargs", {}),
+            fail_reason="mapping already partial",
+        )
+        return self.operator_class(
+            dag=dag,
+            task_id=self.task_id,
+            op_args=op_args,
+            op_kwargs=op_kwargs,
+            multiple_outputs=self.multiple_outputs,
+            python_callable=self.python_callable,
+            **self.partial_kwargs,
+            **self.mapped_kwargs,
+        )
 
 
 class Task(Generic[Function]):
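
A standalone sketch of the duplicate-argument check that the new
_merge_kwargs helper above performs (reimplemented here for
illustration, not imported from Airflow; the real helper also
distinguishes singular/plural in its error message):

    from typing import Any, Dict

    def merge_kwargs(kwargs1: Dict[str, Any], kwargs2: Dict[str, Any], *, fail_reason: str) -> Dict[str, Any]:
        # A key supplied both at partial() time and at map() time is an error.
        duplicated = set(kwargs1).intersection(kwargs2)
        if duplicated:
            names = ", ".join(sorted(duplicated))
            raise TypeError(f"{fail_reason} argument(s): {names}")
        return {**kwargs1, **kwargs2}

    merged = merge_kwargs({"multiple": 3}, {"number": [1, 2, 3]}, fail_reason="duplicate partial")
    assert merged == {"multiple": 3, "number": [1, 2, 3]}
    # merge_kwargs({"multiple": 3}, {"multiple": 2}, fail_reason="duplicate partial")
    # would raise: TypeError: duplicate partial argument(s): multiple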
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 8f96153..35a0fbb 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -82,7 +82,6 @@ from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.weight_rule import WeightRule
 
 if TYPE_CHECKING:
-    from airflow.decorators.base import _TaskDecorator
     from airflow.models.dag import DAG
     from airflow.utils.task_group import TaskGroup
 
@@ -243,7 +242,7 @@ class BaseOperatorMeta(abc.ABCMeta):
         return new_cls
 
     # The class level partial function. This is what handles the actual mapping
-    def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs):
+    def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs) -> "MappedOperator":
         operator_class = cast("Type[BaseOperator]", cls)
         # Validate that the args we passed are known -- at call/DAG parse time, not run time!
         _validate_kwarg_names_for_mapping(operator_class, "partial", kwargs)
@@ -1632,7 +1631,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
             dag._remove_task(operator.task_id)
 
         operator_init_kwargs: dict = operator._BaseOperator__init_kwargs  # type: ignore
-        return MappedOperator(
+        return cls(
             operator_class=type(operator),
             task_id=operator.task_id,
             task_group=task_group,
@@ -1649,37 +1648,6 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
         )
 
     @classmethod
-    def from_decorator(
-        cls,
-        *,
-        decorator: "_TaskDecorator",
-        dag: Optional["DAG"],
-        task_group: Optional["TaskGroup"],
-        task_id: str,
-        mapped_kwargs: Dict[str, Any],
-    ) -> "MappedOperator":
-        """Create a mapped operator from a task decorator.
-
-        Different from ``from_operator``, this DOES NOT validate ``mapped_kwargs``.
-        The task decorator calling this should be responsible for validation.
-        """
-        from airflow.models.xcom_arg import XComArg
-
-        operator = MappedOperator(
-            operator_class=decorator.operator_class,
-            partial_kwargs=decorator.kwargs,
-            mapped_kwargs={},
-            task_id=task_id,
-            dag=dag,
-            task_group=task_group,
-            deps=cls._deps(decorator.operator_class.deps),
-        )
-        operator.mapped_kwargs.update(mapped_kwargs)
-        for arg in mapped_kwargs.values():
-            XComArg.apply_upstream_relationship(operator, arg)
-        return operator
-
-    @classmethod
     def _deps(cls, deps: Iterable[BaseTIDep]):
         if deps is BaseOperator.deps:
             return cls.DEFAULT_DEPS
@@ -1749,7 +1717,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
     @classmethod
     def get_serialized_fields(cls):
         if cls.__serialized_fields is None:
-            fields_dict = attr.fields_dict(cls)
+            fields_dict = attr.fields_dict(MappedOperator)
             cls.__serialized_fields = frozenset(
                 fields_dict.keys()
                 - {
@@ -1902,22 +1870,17 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
 
         return ret
 
-    def unmap(self) -> BaseOperator:
-        """Get the "normal" Operator after applying the current mapping"""
+    def create_unmapped_operator(self, dag: "DAG") -> BaseOperator:
         assert not isinstance(self.operator_class, str)
+        return self.operator_class(dag=dag, task_id=self.task_id, **self.partial_kwargs, **self.mapped_kwargs)
 
+    def unmap(self) -> BaseOperator:
+        """Get the "normal" Operator after applying the current mapping"""
         dag = self.get_dag()
         if not dag:
-            raise RuntimeError("Cannot unmapp a task unless it has a dag")
-
-        args = {
-            **self.partial_kwargs,
-            **self.mapped_kwargs,
-        }
+            raise RuntimeError("Cannot unmap a task unless it has a DAG")
         dag._remove_task(self.task_id)
-        task = self.operator_class(task_id=self.task_id, dag=self.dag, **args)
-
-        return task
+        return self.create_unmapped_operator(dag)
 
 
 # TODO: Deprecate for Airflow 3.0
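
The unmap() change above is a small template-method refactor: the
shared bookkeeping (requiring a DAG, removing the mapped task) stays in
MappedOperator.unmap(), while subclasses such as DecoratedMappedOperator
override only create_unmapped_operator(). A toy sketch of the pattern,
with hypothetical class names standing in for the Airflow ones:

    class ToyMapped:
        def __init__(self, dag, task_id, partial_kwargs, mapped_kwargs):
            self.dag = dag
            self.task_id = task_id
            self.partial_kwargs = partial_kwargs
            self.mapped_kwargs = mapped_kwargs

        def create_unmapped_operator(self, dag):
            # Default hook; a plain dict stands in for instantiating the
            # real operator class with these constructor arguments.
            return {"dag": dag, "task_id": self.task_id, **self.partial_kwargs, **self.mapped_kwargs}

        def unmap(self):
            # Shared bookkeeping lives here; subclasses override only the hook.
            if self.dag is None:
                raise RuntimeError("Cannot unmap a task unless it has a DAG")
            return self.create_unmapped_operator(self.dag)

    class ToyDecoratedMapped(ToyMapped):
        def create_unmapped_operator(self, dag):
            # Decorated tasks merge op_kwargs from both sides first (the
            # real code raises on duplicate keys via _merge_kwargs).
            op_kwargs = {**self.partial_kwargs.pop("op_kwargs", {}),
                         **self.mapped_kwargs.pop("op_kwargs", {})}
            unmapped = super().create_unmapped_operator(dag)
            unmapped["op_kwargs"] = op_kwargs
            return unmapped

    toy = ToyDecoratedMapped("dag", "t", {"op_kwargs": {"arg1": 1}}, {"op_kwargs": {"arg2": 2}})
    assert toy.unmap()["op_kwargs"] == {"arg1": 1, "arg2": 2}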
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 4996b9a..f10032d 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -1713,7 +1713,7 @@ class TaskInstance(Base, LoggingMixin):
         test_mode: Optional[bool] = None,
         force_fail: bool = False,
         error_file: Optional[str] = None,
-        session=NEW_SESSION,
+        session: Session = NEW_SESSION,
     ) -> None:
         """Handle Failure for the TaskInstance"""
         if test_mode is None:
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index d6abda7..017f227 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -16,6 +16,7 @@
 # under the License.
 
 """Serialized DAG and BaseOperator"""
+import contextlib
 import datetime
 import enum
 import logging
@@ -168,7 +169,7 @@ def _decode_timetable(var: Dict[str, Any]) -> Timetable:
     return timetable_class.deserialize(var[Encoding.VAR])
 
 
-class _XcomRef(NamedTuple):
+class _XComRef(NamedTuple):
     """
     Used to store info needed to create XComArg when deserializing MappedOperator.
 
@@ -497,8 +498,8 @@ class BaseSerialization:
         return {"key": arg.key, "task_id": arg.operator.task_id}
 
     @classmethod
-    def _deserialize_xcomref(cls, encoded: dict) -> _XcomRef:
-        return _XcomRef(key=encoded['key'], task_id=encoded['task_id'])
+    def _deserialize_xcomref(cls, encoded: dict) -> _XComRef:
+        return _XComRef(key=encoded['key'], task_id=encoded['task_id'])
 
 
 class DependencyDetector:
@@ -566,9 +567,19 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
 
     @classmethod
     def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]:
-
         stock_deps = op.deps is MappedOperator.DEFAULT_DEPS
         serialize_op = cls._serialize_node(op, include_deps=not stock_deps)
+
+        # Simplify op_kwargs format. It must be a dict, so we flatten it.
+        with contextlib.suppress(KeyError):
+            op_kwargs = serialize_op["mapped_kwargs"]["op_kwargs"]
+            assert op_kwargs[Encoding.TYPE] == DAT.DICT
+            serialize_op["mapped_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR]
+        with contextlib.suppress(KeyError):
+            op_kwargs = serialize_op["partial_kwargs"]["op_kwargs"]
+            assert op_kwargs[Encoding.TYPE] == DAT.DICT
+            serialize_op["partial_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR]
+
         # It must be a class at this point for it to work, not a string
         assert isinstance(op.operator_class, type)
         serialize_op['_task_type'] = op.operator_class.__name__
@@ -715,7 +726,13 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
             elif k == "params":
                 v = cls._deserialize_params_dict(v)
             elif k in ("mapped_kwargs", "partial_kwargs"):
+                if "op_kwargs" not in v:
+                    op_kwargs: Optional[dict] = None
+                else:
+                    op_kwargs = {arg: cls._deserialize(value) for arg, value in v.pop("op_kwargs").items()}
                 v = {arg: cls._deserialize(value) for arg, value in v.items()}
+                if op_kwargs is not None:
+                    v["op_kwargs"] = op_kwargs
             elif k in cls._decorated_fields or k not in op.get_serialized_fields():
                 v = cls._deserialize(v)
             # else use v as it is
@@ -1002,7 +1019,7 @@ class SerializedDAG(DAG, BaseSerialization):
             if isinstance(task, MappedOperator):
                 for d in (task.mapped_kwargs, task.partial_kwargs):
                     for k, v in d.items():
-                        if not isinstance(v, _XcomRef):
+                        if not isinstance(v, _XComRef):
                             continue
 
                         d[k] = XComArg(operator=dag.get_task(v.task_id), key=v.key)
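
A standalone sketch of the op_kwargs flattening performed in
serialize_mapped_operator above; "__type" and "__var" are the literal
values of Encoding.TYPE and Encoding.VAR, as seen in the serde test
later in this diff:

    import contextlib

    # What BaseSerialization emits before flattening: op_kwargs wrapped
    # in a type envelope, like any other serialized dict.
    serialize_op = {
        "mapped_kwargs": {
            "op_kwargs": {
                "__type": "dict",
                "__var": {"arg4": {"__type": "xcomref", "__var": {"task_id": "op1", "key": "my_key"}}},
            },
        },
    }

    # op_kwargs is by definition a dict, so the envelope carries no
    # information; store only the inner mapping.
    with contextlib.suppress(KeyError):
        op_kwargs = serialize_op["mapped_kwargs"]["op_kwargs"]
        assert op_kwargs["__type"] == "dict"
        serialize_op["mapped_kwargs"]["op_kwargs"] = op_kwargs["__var"]

    assert serialize_op["mapped_kwargs"]["op_kwargs"] == {
        "arg4": {"__type": "xcomref", "__var": {"task_id": "op1", "key": "my_key"}}
    }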
diff --git a/tests/dags/test_mapped_taskflow.py b/tests/dags/test_mapped_taskflow.py
new file mode 100644
index 0000000..f21a9a5
--- /dev/null
+++ b/tests/dags/test_mapped_taskflow.py
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow import DAG
+from airflow.utils.dates import days_ago
+
+with DAG(dag_id='test_mapped_taskflow', start_date=days_ago(2)) as dag:
+
+    @dag.task
+    def make_list():
+        return [1, 2, {'a': 'b'}]
+
+    @dag.task
+    def consumer(value):
+        print(repr(value))
+
+    consumer.map(value=make_list())
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 0c93b49..ee94fde 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -17,7 +17,7 @@
 # under the License.
 import sys
 from collections import namedtuple
-from datetime import date, timedelta
+from datetime import date, datetime, timedelta
 from typing import Dict  # noqa: F401  # This is used by annotation tests.
 from typing import Tuple
 
@@ -490,7 +490,7 @@ def test_mapped_decorator() -> None:
     assert isinstance(doubled_0, XComArg)
     assert isinstance(doubled_0.operator, MappedOperator)
     assert doubled_0.operator.task_id == "double"
-    assert doubled_0.operator.mapped_kwargs == {"number": literal}
+    assert doubled_0.operator.mapped_kwargs == {"op_args": [], "op_kwargs": {"number": literal}}
 
     assert doubled_1.operator.task_id == "double__1"
 
@@ -514,25 +514,68 @@ def test_partial_mapped_decorator() -> None:
     def product(number: int, multiple: int):
         return number * multiple
 
+    literal = [1, 2, 3]
+
     with DAG('test_dag', start_date=DEFAULT_DATE) as dag:
-        literal = [1, 2, 3]
-        quadrupled = product.partial(task_id='times_4', multiple=3).map(number=literal)
+        quadrupled = product.partial(multiple=3).map(number=literal)
         doubled = product.partial(multiple=2).map(number=literal)
         trippled = product.partial(multiple=3).map(number=literal)
 
-        product.partial(multiple=2)
+        product.partial(multiple=2)  # No operator is actually created.
+
+    assert dag.task_dict == {
+        "product": quadrupled.operator,
+        "product__1": doubled.operator,
+        "product__2": trippled.operator,
+    }
 
     assert isinstance(doubled, XComArg)
     assert isinstance(doubled.operator, MappedOperator)
-    assert doubled.operator.task_id == "product"
-    assert doubled.operator.mapped_kwargs == {"number": literal}
-    assert doubled.operator.partial_kwargs == {"task_id": "product", "multiple": 2}
+    assert doubled.operator.mapped_kwargs == {"op_args": [], "op_kwargs": {"number": literal}}
+    assert doubled.operator.partial_kwargs == {"op_args": [], "op_kwargs": {"multiple": 2}}
 
-    assert trippled.operator.task_id == "product__1"
-    assert trippled.operator.partial_kwargs == {"task_id": "product", "multiple": 3}
-
-    assert quadrupled.operator.task_id == "times_4"
+    assert isinstance(trippled.operator, MappedOperator)  # For type-checking on partial_kwargs.
+    assert trippled.operator.partial_kwargs == {"op_args": [], "op_kwargs": {"multiple": 3}}
 
     assert doubled.operator is not trippled.operator
 
-    assert [quadrupled.operator, doubled.operator, trippled.operator] == dag.tasks
+
+def test_mapped_decorator_unmap_merge_op_kwargs():
+    with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
+
+        @task_decorator
+        def task1():
+            ...
+
+        @task_decorator
+        def task2(arg1, arg2):
+            ...
+
+        task2.partial(arg1=1).map(arg2=task1())
+
+    unmapped = dag.get_task("task2").unmap()
+    assert set(unmapped.op_kwargs) == {"arg1", "arg2"}
+
+
+def test_mapped_decorator_unmap_converts_partial_kwargs():
+    with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
+
+        @task_decorator
+        def task1(arg):
+            ...
+
+        @task_decorator(retry_delay=30)
+        def task2(arg1, arg2):
+            ...
+
+        task2.partial(arg1=1).map(arg2=task1.map(arg=[1, 2]))
+
+    # Arguments to the task decorator are stored in partial_kwargs, and
+    # converted into their intended form after the task is unmapped.
+    mapped_task2 = dag.get_task("task2")
+    assert mapped_task2.partial_kwargs["retry_delay"] == 30
+    assert mapped_task2.unmap().retry_delay == timedelta(seconds=30)
+
+    mapped_task1 = dag.get_task("task1")
+    assert "retry_delay" not in mapped_task1.partial_kwargs
+    assert mapped_task1.unmap().retry_delay == timedelta(seconds=300)  # Operator default.
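
The last test above relies on BaseOperator itself normalizing arguments
at construction time: partial_kwargs holds the raw value (30) until
unmap() instantiates the real operator, which converts it. A minimal
sketch of that conversion, assuming the numeric-seconds form that
BaseOperator accepts for retry_delay:

    from datetime import timedelta

    partial_kwargs = {"retry_delay": 30}  # raw value, stored at decoration time

    # The normalization applied once the real operator is constructed.
    raw = partial_kwargs["retry_delay"]
    retry_delay = raw if isinstance(raw, timedelta) else timedelta(seconds=raw)
    assert retry_delay == timedelta(seconds=30)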
diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py
index 0878f63..40593d5 100644
--- a/tests/jobs/test_backfill_job.py
+++ b/tests/jobs/test_backfill_job.py
@@ -47,7 +47,13 @@ from airflow.utils.state import State
 from airflow.utils.timeout import timeout
 from airflow.utils.types import DagRunType
 from tests.models import TEST_DAGS_FOLDER
-from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, set_default_pool_slots
+from tests.test_utils.db import (
+    clear_db_dags,
+    clear_db_pools,
+    clear_db_runs,
+    clear_db_xcom,
+    set_default_pool_slots,
+)
 from tests.test_utils.mock_executor import MockExecutor
 from tests.test_utils.timetables import cron_timetable
 
@@ -66,6 +72,7 @@ class TestBackfillJob:
     def clean_db():
         clear_db_dags()
         clear_db_runs()
+        clear_db_xcom()
         clear_db_pools()
 
     @pytest.fixture(autouse=True)
@@ -1512,13 +1519,14 @@ class TestBackfillJob:
         job.run()
         assert executor.job_id is not None
 
-    def test_mapped_dag(self, dag_maker):
+    @pytest.mark.parametrize("dag_id", ["test_mapped_classic", "test_mapped_taskflow"])
+    def test_mapped_dag(self, dag_id):
         """End-to-end test of a simple mapped dag"""
         # Use SequentialExecutor for more predictable test behaviour
         from airflow.executors.sequential_executor import SequentialExecutor
 
-        self.dagbag.process_file(str(TEST_DAGS_FOLDER / 'test_mapped_classic.py'))
-        dag = self.dagbag.get_dag('test_mapped_classic')
+        self.dagbag.process_file(str(TEST_DAGS_FOLDER / f'{dag_id}.py'))
+        dag = self.dagbag.get_dag(dag_id)
 
         # This needs a real executor to run, so that the `make_list` task can write out the TaskMap
 
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 447b173..1e8d510 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1654,6 +1654,59 @@ def test_mapped_operator_xcomarg_serde():
     assert xcom_arg.operator is serialized_dag.task_dict['op1']
 
 
+def test_mapped_decorator_serde():
+    from airflow.decorators import task
+    from airflow.models.xcom_arg import XComArg
+    from airflow.serialization.serialized_objects import _XComRef
+
+    with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
+        op1 = BaseOperator(task_id="op1")
+        xcomarg = XComArg(op1, "my_key")
+
+        @task(retry_delay=30)
+        def x(arg1, arg2, arg3, arg4):
+            print(arg1, arg2, arg3, arg4)
+
+        x.partial("foo", arg3=[1, 2, {"a": "b"}]).map({"a": 1, "b": 2}, arg4=xcomarg)
+
+    original = dag.get_task("x")
+
+    serialized = SerializedBaseOperator._serialize(original)
+    assert serialized == {
+        '_is_dummy': False,
+        '_is_mapped': True,
+        '_task_module': 'airflow.decorators.python',
+        '_task_type': '_PythonDecoratedOperator',
+        'downstream_task_ids': [],
+        'partial_kwargs': {
+            'op_args': ["foo"],
+            'op_kwargs': {'arg3': [1, 2, {"__type": "dict", "__var": {'a': 'b'}}]},
+            'retry_delay': 30,
+        },
+        'mapped_kwargs': {
+            'op_args': [{"__type": "dict", "__var": {'a': 1, 'b': 2}}],
+            'op_kwargs': {'arg4': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'my_key'}}},
+        },
+        'task_id': 'x',
+        'template_ext': [],
+        'template_fields': ['op_args', 'op_kwargs'],
+    }
+
+    deserialized = SerializedBaseOperator.deserialize_operator(serialized)
+    assert isinstance(deserialized, MappedOperator)
+    assert deserialized.deps is MappedOperator.DEFAULT_DEPS
+
+    assert deserialized.mapped_kwargs == {
+        "op_args": [{"a": 1, "b": 2}],
+        "op_kwargs": {"arg4": _XComRef("op1", "my_key")},
+    }
+    assert deserialized.partial_kwargs == {
+        "retry_delay": 30,
+        "op_args": ["foo"],
+        "op_kwargs": {"arg3": [1, 2, {"a": "b"}]},
+    }
+
+
 def test_mapped_task_group_serde():
     execution_date = datetime(2020, 1, 1)