You are viewing a plain-text version of this content; the hyperlink to its canonical version ("here") was not preserved in this format.
Posted to commits@airflow.apache.org by ur...@apache.org on 2022/06/22 07:48:57 UTC

[airflow] branch main updated: Remove special serde logic for mapped op_kwargs (#23860)

This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5877f45d65 Remove special serde logic for mapped op_kwargs (#23860)
5877f45d65 is described below

commit 5877f45d65d5aa864941efebd2040661b6f89cb1
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Wed Jun 22 15:48:50 2022 +0800

    Remove special serde logic for mapped op_kwargs (#23860)
    
    Co-authored-by: Daniel Standish <15...@users.noreply.github.com>
---
 airflow/decorators/base.py                    | 10 +--------
 airflow/models/mappedoperator.py              |  1 +
 airflow/serialization/serialized_objects.py   | 29 +++++----------------------
 tests/serialization/test_dag_serialization.py | 29 +++++++++++++++------------
 4 files changed, 23 insertions(+), 46 deletions(-)

diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 1b5b5b760b..2a2ce2da96 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -39,7 +39,7 @@ from typing import (
 import attr
 import typing_extensions
 
-from airflow.compat.functools import cache, cached_property
+from airflow.compat.functools import cached_property
 from airflow.exceptions import AirflowException
 from airflow.models.abstractoperator import DEFAULT_RETRIES, DEFAULT_RETRY_DELAY
 from airflow.models.baseoperator import (
@@ -427,14 +427,6 @@ class DecoratedMappedOperator(MappedOperator):
     def __hash__(self):
         return id(self)
 
-    @classmethod
-    @cache
-    def get_serialized_fields(cls):
-        # The magic super() doesn't work here, so we use the explicit form.
-        # Not using super(..., cls) to work around pyupgrade bug.
-        sup = super(DecoratedMappedOperator, DecoratedMappedOperator)
-        return sup.get_serialized_fields() | {"mapped_op_kwargs"}
-
     def __attrs_post_init__(self):
         # The magic super() doesn't work here, so we use the explicit form.
         # Not using super(..., self) to work around pyupgrade bug.
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index ba7328d2f7..21a265e6e9 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -324,6 +324,7 @@ class MappedOperator(AbstractOperator):
             "dag",
             "deps",
             "is_mapped",
+            "mapped_kwargs",  # This is needed to be able to accept XComArg.
             "subdag",
             "task_group",
             "upstream_task_ids",
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 85286688f5..bd0430ba26 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -16,7 +16,7 @@
 # under the License.
 
 """Serialized DAG and BaseOperator"""
-import contextlib
+
 import datetime
 import enum
 import logging
@@ -593,6 +593,9 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
     def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]:
         serialized_op = cls._serialize_node(op, include_deps=op.deps is MappedOperator.deps_for(BaseOperator))
 
+        # Handle mapped_kwargs and mapped_op_kwargs.
+        serialized_op[op._expansion_kwargs_attr] = cls._serialize(op._get_expansion_kwargs())
+
         # Simplify partial_kwargs by comparing it to the most barebone object.
         # Remove all entries that are simply default values.
         serialized_partial = serialized_op["partial_kwargs"]
@@ -604,20 +607,6 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
             if v == default:
                 del serialized_partial[k]
 
-        # Simplify op_kwargs format. It must be a dict, so we flatten it.
-        with contextlib.suppress(KeyError):
-            op_kwargs = serialized_op["mapped_kwargs"]["op_kwargs"]
-            assert op_kwargs[Encoding.TYPE] == DAT.DICT
-            serialized_op["mapped_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR]
-        with contextlib.suppress(KeyError):
-            op_kwargs = serialized_op["partial_kwargs"]["op_kwargs"]
-            assert op_kwargs[Encoding.TYPE] == DAT.DICT
-            serialized_op["partial_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR]
-        with contextlib.suppress(KeyError):
-            op_kwargs = serialized_op["mapped_op_kwargs"]
-            assert op_kwargs[Encoding.TYPE] == DAT.DICT
-            serialized_op["mapped_op_kwargs"] = op_kwargs[Encoding.VAR]
-
         serialized_op["_is_mapped"] = True
         return serialized_op
 
@@ -753,15 +742,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
                 v = cls._deserialize_deps(v)
             elif k == "params":
                 v = cls._deserialize_params_dict(v)
-            elif k in ("mapped_kwargs", "partial_kwargs"):
-                if "op_kwargs" not in v:
-                    op_kwargs: Optional[dict] = None
-                else:
-                    op_kwargs = {arg: cls._deserialize(value) for arg, value in v.pop("op_kwargs").items()}
-                v = {arg: cls._deserialize(value) for arg, value in v.items()}
-                if op_kwargs is not None:
-                    v["op_kwargs"] = op_kwargs
-            elif k == "mapped_op_kwargs":
+            elif k == "partial_kwargs":
                 v = {arg: cls._deserialize(value) for arg, value in v.items()}
             elif k in cls._decorated_fields or k not in op.get_serialized_fields():
                 v = cls._deserialize(v)
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index fe9fc7c7e5..7d6a43e933 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1688,18 +1688,13 @@ def test_mapped_operator_serde():
         '_task_type': 'BashOperator',
         'downstream_task_ids': [],
         'mapped_kwargs': {
-            'bash_command': [
-                1,
-                2,
-                {"__type": "dict", "__var": {'a': 'b'}},
-            ]
+            "__type": "dict",
+            "__var": {'bash_command': [1, 2, {"__type": "dict", "__var": {'a': 'b'}}]},
         },
         'partial_kwargs': {
             'executor_config': {
                 '__type': 'dict',
-                '__var': {
-                    'dict': {"__type": "dict", "__var": {'sub': 'value'}},
-                },
+                '__var': {'dict': {"__type": "dict", "__var": {'sub': 'value'}}},
             },
         },
         'task_id': 'a',
@@ -1744,7 +1739,10 @@ def test_mapped_operator_xcomarg_serde():
         '_task_module': 'tests.test_utils.mock_operators',
         '_task_type': 'MockOperator',
         'downstream_task_ids': [],
-        'mapped_kwargs': {'arg2': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}}},
+        'mapped_kwargs': {
+            "__type": "dict",
+            "__var": {'arg2': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}}},
+        },
         'partial_kwargs': {},
         'task_id': 'task_2',
         'template_fields': ['arg1', 'arg2'],
@@ -1825,13 +1823,18 @@ def test_mapped_decorator_serde():
         'downstream_task_ids': [],
         'partial_kwargs': {
             'op_args': [],
-            'op_kwargs': {'arg1': [1, 2, {"__type": "dict", "__var": {'a': 'b'}}]},
+            'op_kwargs': {
+                '__type': 'dict',
+                '__var': {'arg1': [1, 2, {"__type": "dict", "__var": {'a': 'b'}}]},
+            },
             'retry_delay': {'__type': 'timedelta', '__var': 30.0},
         },
-        'mapped_kwargs': {},
         'mapped_op_kwargs': {
-            'arg2': {"__type": "dict", "__var": {'a': 1, 'b': 2}},
-            'arg3': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}},
+            "__type": "dict",
+            "__var": {
+                'arg2': {"__type": "dict", "__var": {'a': 1, 'b': 2}},
+                'arg3': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}},
+            },
         },
         'operator_extra_links': [],
         'ui_color': '#ffefeb',