You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ur...@apache.org on 2022/06/22 07:48:57 UTC
[airflow] branch main updated: Remove special serde logic for mapped op_kwargs (#23860)
This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5877f45d65 Remove special serde logic for mapped op_kwargs (#23860)
5877f45d65 is described below
commit 5877f45d65d5aa864941efebd2040661b6f89cb1
Author: Tzu-ping Chung <ur...@gmail.com>
AuthorDate: Wed Jun 22 15:48:50 2022 +0800
Remove special serde logic for mapped op_kwargs (#23860)
Co-authored-by: Daniel Standish <15...@users.noreply.github.com>
---
airflow/decorators/base.py                    | 10 +--------
airflow/models/mappedoperator.py              |  1 +
airflow/serialization/serialized_objects.py   | 29 +++++----------------------
tests/serialization/test_dag_serialization.py | 29 +++++++++++++++------------
4 files changed, 23 insertions(+), 46 deletions(-)
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 1b5b5b760b..2a2ce2da96 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -39,7 +39,7 @@ from typing import (
import attr
import typing_extensions
-from airflow.compat.functools import cache, cached_property
+from airflow.compat.functools import cached_property
from airflow.exceptions import AirflowException
from airflow.models.abstractoperator import DEFAULT_RETRIES, DEFAULT_RETRY_DELAY
from airflow.models.baseoperator import (
@@ -427,14 +427,6 @@ class DecoratedMappedOperator(MappedOperator):
def __hash__(self):
return id(self)
- @classmethod
- @cache
- def get_serialized_fields(cls):
- # The magic super() doesn't work here, so we use the explicit form.
- # Not using super(..., cls) to work around pyupgrade bug.
- sup = super(DecoratedMappedOperator, DecoratedMappedOperator)
- return sup.get_serialized_fields() | {"mapped_op_kwargs"}
-
def __attrs_post_init__(self):
# The magic super() doesn't work here, so we use the explicit form.
# Not using super(..., self) to work around pyupgrade bug.
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index ba7328d2f7..21a265e6e9 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -324,6 +324,7 @@ class MappedOperator(AbstractOperator):
"dag",
"deps",
"is_mapped",
+ "mapped_kwargs", # This is needed to be able to accept XComArg.
"subdag",
"task_group",
"upstream_task_ids",
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index 85286688f5..bd0430ba26 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -16,7 +16,7 @@
# under the License.
"""Serialized DAG and BaseOperator"""
-import contextlib
+
import datetime
import enum
import logging
@@ -593,6 +593,9 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]:
serialized_op = cls._serialize_node(op, include_deps=op.deps is MappedOperator.deps_for(BaseOperator))
+ # Handle mapped_kwargs and mapped_op_kwargs.
+ serialized_op[op._expansion_kwargs_attr] = cls._serialize(op._get_expansion_kwargs())
+
# Simplify partial_kwargs by comparing it to the most barebone object.
# Remove all entries that are simply default values.
serialized_partial = serialized_op["partial_kwargs"]
@@ -604,20 +607,6 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
if v == default:
del serialized_partial[k]
- # Simplify op_kwargs format. It must be a dict, so we flatten it.
- with contextlib.suppress(KeyError):
- op_kwargs = serialized_op["mapped_kwargs"]["op_kwargs"]
- assert op_kwargs[Encoding.TYPE] == DAT.DICT
- serialized_op["mapped_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR]
- with contextlib.suppress(KeyError):
- op_kwargs = serialized_op["partial_kwargs"]["op_kwargs"]
- assert op_kwargs[Encoding.TYPE] == DAT.DICT
- serialized_op["partial_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR]
- with contextlib.suppress(KeyError):
- op_kwargs = serialized_op["mapped_op_kwargs"]
- assert op_kwargs[Encoding.TYPE] == DAT.DICT
- serialized_op["mapped_op_kwargs"] = op_kwargs[Encoding.VAR]
-
serialized_op["_is_mapped"] = True
return serialized_op
@@ -753,15 +742,7 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
v = cls._deserialize_deps(v)
elif k == "params":
v = cls._deserialize_params_dict(v)
- elif k in ("mapped_kwargs", "partial_kwargs"):
- if "op_kwargs" not in v:
- op_kwargs: Optional[dict] = None
- else:
- op_kwargs = {arg: cls._deserialize(value) for arg, value in v.pop("op_kwargs").items()}
- v = {arg: cls._deserialize(value) for arg, value in v.items()}
- if op_kwargs is not None:
- v["op_kwargs"] = op_kwargs
- elif k == "mapped_op_kwargs":
+ elif k == "partial_kwargs":
v = {arg: cls._deserialize(value) for arg, value in v.items()}
elif k in cls._decorated_fields or k not in op.get_serialized_fields():
v = cls._deserialize(v)
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index fe9fc7c7e5..7d6a43e933 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1688,18 +1688,13 @@ def test_mapped_operator_serde():
'_task_type': 'BashOperator',
'downstream_task_ids': [],
'mapped_kwargs': {
- 'bash_command': [
- 1,
- 2,
- {"__type": "dict", "__var": {'a': 'b'}},
- ]
+ "__type": "dict",
+ "__var": {'bash_command': [1, 2, {"__type": "dict", "__var": {'a': 'b'}}]},
},
'partial_kwargs': {
'executor_config': {
'__type': 'dict',
- '__var': {
- 'dict': {"__type": "dict", "__var": {'sub': 'value'}},
- },
+ '__var': {'dict': {"__type": "dict", "__var": {'sub': 'value'}}},
},
},
'task_id': 'a',
@@ -1744,7 +1739,10 @@ def test_mapped_operator_xcomarg_serde():
'_task_module': 'tests.test_utils.mock_operators',
'_task_type': 'MockOperator',
'downstream_task_ids': [],
- 'mapped_kwargs': {'arg2': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}}},
+ 'mapped_kwargs': {
+ "__type": "dict",
+ "__var": {'arg2': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}}},
+ },
'partial_kwargs': {},
'task_id': 'task_2',
'template_fields': ['arg1', 'arg2'],
@@ -1825,13 +1823,18 @@ def test_mapped_decorator_serde():
'downstream_task_ids': [],
'partial_kwargs': {
'op_args': [],
- 'op_kwargs': {'arg1': [1, 2, {"__type": "dict", "__var": {'a': 'b'}}]},
+ 'op_kwargs': {
+ '__type': 'dict',
+ '__var': {'arg1': [1, 2, {"__type": "dict", "__var": {'a': 'b'}}]},
+ },
'retry_delay': {'__type': 'timedelta', '__var': 30.0},
},
- 'mapped_kwargs': {},
'mapped_op_kwargs': {
- 'arg2': {"__type": "dict", "__var": {'a': 1, 'b': 2}},
- 'arg3': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}},
+ "__type": "dict",
+ "__var": {
+ 'arg2': {"__type": "dict", "__var": {'a': 1, 'b': 2}},
+ 'arg3': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'return_value'}},
+ },
},
'operator_extra_links': [],
'ui_color': '#ffefeb',