You are viewing a plain text version of this content. The canonical link for it was not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by ur...@apache.org on 2022/02/10 07:08:46 UTC
[airflow] branch main updated: Rewrite decorated task mapping (#21328)
This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new fded2ca Rewrite decorated task mapping (#21328)
fded2ca is described below
commit fded2ca0b9c995737b401896b89e5c9fd7f24c91
Author: Tzu-ping Chung <tp...@astronomer.io>
AuthorDate: Thu Feb 10 15:07:50 2022 +0800
Rewrite decorated task mapping (#21328)
---
airflow/decorators/base.py | 92 ++++++++++++++++++++++-----
airflow/models/baseoperator.py | 55 +++-------------
airflow/models/taskinstance.py | 2 +-
airflow/serialization/serialized_objects.py | 27 ++++++--
tests/dags/test_mapped_taskflow.py | 31 +++++++++
tests/decorators/test_python.py | 69 ++++++++++++++++----
tests/jobs/test_backfill_job.py | 16 +++--
tests/serialization/test_dag_serialization.py | 53 +++++++++++++++
8 files changed, 259 insertions(+), 86 deletions(-)
diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py
index 9cf423f..53a12c6 100644
--- a/airflow/decorators/base.py
+++ b/airflow/decorators/base.py
@@ -280,30 +280,88 @@ class _TaskDecorator(Generic[Function, OperatorSubclass]):
names = ", ".join(repr(n) for n in unknown_args)
raise TypeError(f'{funcname} got unexpected keyword arguments {names}')
- def map(
- self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
- ) -> XComArg:
+ def map(self, *args, **kwargs) -> XComArg:
self._validate_arg_names("map", kwargs)
- dag = dag or DagContext.get_current_dag()
- task_group = task_group or TaskGroupContext.get_current_task_group(dag)
- task_id = get_unique_task_id(self.kwargs['task_id'], dag, task_group)
- operator = MappedOperator.from_decorator(
- decorator=self,
+ partial_kwargs = self.kwargs.copy()
+ dag = partial_kwargs.pop("dag", DagContext.get_current_dag())
+ task_group = partial_kwargs.pop("task_group", TaskGroupContext.get_current_task_group(dag))
+ task_id = get_unique_task_id(partial_kwargs.pop("task_id"), dag, task_group)
+
+ # Unfortunately attrs's type hinting support does not work well with
+ # subclassing; it complains that arguments forwarded to the superclass
+ # are "unexpected" (they are fine at runtime).
+ operator = cast(Any, DecoratedMappedOperator)(
+ operator_class=self.operator_class,
+ partial_kwargs=partial_kwargs,
+ mapped_kwargs={},
+ task_id=task_id,
dag=dag,
task_group=task_group,
- task_id=task_id,
- mapped_kwargs=kwargs,
+ deps=MappedOperator._deps(self.operator_class.deps),
+ multiple_outputs=self.multiple_outputs,
+ python_callable=self.function,
)
+
+ operator.mapped_kwargs["op_args"] = list(args)
+ operator.mapped_kwargs["op_kwargs"] = kwargs
+
+ for arg in itertools.chain(args, kwargs.values()):
+ XComArg.apply_upstream_relationship(operator, arg)
return XComArg(operator=operator)
- def partial(
- self, *, dag: Optional["DAG"] = None, task_group: Optional["TaskGroup"] = None, **kwargs
- ) -> "_TaskDecorator[Function, OperatorSubclass]":
- self._validate_arg_names("partial", kwargs, {'task_id'})
- partial_kwargs = self.kwargs.copy()
- partial_kwargs.update(kwargs)
- return attr.evolve(self, kwargs=partial_kwargs)
+ def partial(self, *args, **kwargs) -> "_TaskDecorator[Function, OperatorSubclass]":
+ self._validate_arg_names("partial", kwargs)
+
+ op_args = self.kwargs.get("op_args", [])
+ op_args.extend(args)
+
+ op_kwargs = self.kwargs.get("op_kwargs", {})
+ op_kwargs = _merge_kwargs(op_kwargs, kwargs, fail_reason="duplicate partial")
+
+ return attr.evolve(self, kwargs={**self.kwargs, "op_args": op_args, "op_kwargs": op_kwargs})
+
+
+def _merge_kwargs(
+ kwargs1: Dict[str, XComArg],
+ kwargs2: Dict[str, XComArg],
+ *,
+ fail_reason: str,
+) -> Dict[str, XComArg]:
+ duplicated_keys = set(kwargs1).intersection(kwargs2)
+ if len(duplicated_keys) == 1:
+ raise TypeError(f"{fail_reason} argument: {duplicated_keys.pop()}")
+ elif duplicated_keys:
+ duplicated_keys_display = ", ".join(sorted(duplicated_keys))
+ raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}")
+ return {**kwargs1, **kwargs2}
+
+
+@attr.define(kw_only=True)
+class DecoratedMappedOperator(MappedOperator):
+ """MappedOperator implementation for @task-decorated task function."""
+
+ multiple_outputs: bool
+ python_callable: Callable
+
+ def create_unmapped_operator(self, dag: "DAG") -> BaseOperator:
+ assert not isinstance(self.operator_class, str)
+ op_args = self.partial_kwargs.pop("op_args", []) + self.mapped_kwargs.pop("op_args", [])
+ op_kwargs = _merge_kwargs(
+ self.partial_kwargs.pop("op_kwargs", {}),
+ self.mapped_kwargs.pop("op_kwargs", {}),
+ fail_reason="mapping already partial",
+ )
+ return self.operator_class(
+ dag=dag,
+ task_id=self.task_id,
+ op_args=op_args,
+ op_kwargs=op_kwargs,
+ multiple_outputs=self.multiple_outputs,
+ python_callable=self.python_callable,
+ **self.partial_kwargs,
+ **self.mapped_kwargs,
+ )
class Task(Generic[Function]):
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 8f96153..35a0fbb 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -82,7 +82,6 @@ from airflow.utils.trigger_rule import TriggerRule
from airflow.utils.weight_rule import WeightRule
if TYPE_CHECKING:
- from airflow.decorators.base import _TaskDecorator
from airflow.models.dag import DAG
from airflow.utils.task_group import TaskGroup
@@ -243,7 +242,7 @@ class BaseOperatorMeta(abc.ABCMeta):
return new_cls
# The class level partial function. This is what handles the actual mapping
- def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs):
+ def partial(cls, *, task_id: str, dag: Optional["DAG"] = None, **kwargs) -> "MappedOperator":
operator_class = cast("Type[BaseOperator]", cls)
# Validate that the args we passed are known -- at call/DAG parse time, not run time!
_validate_kwarg_names_for_mapping(operator_class, "partial", kwargs)
@@ -1632,7 +1631,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
dag._remove_task(operator.task_id)
operator_init_kwargs: dict = operator._BaseOperator__init_kwargs # type: ignore
- return MappedOperator(
+ return cls(
operator_class=type(operator),
task_id=operator.task_id,
task_group=task_group,
@@ -1649,37 +1648,6 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
)
@classmethod
- def from_decorator(
- cls,
- *,
- decorator: "_TaskDecorator",
- dag: Optional["DAG"],
- task_group: Optional["TaskGroup"],
- task_id: str,
- mapped_kwargs: Dict[str, Any],
- ) -> "MappedOperator":
- """Create a mapped operator from a task decorator.
-
- Different from ``from_operator``, this DOES NOT validate ``mapped_kwargs``.
- The task decorator calling this should be responsible for validation.
- """
- from airflow.models.xcom_arg import XComArg
-
- operator = MappedOperator(
- operator_class=decorator.operator_class,
- partial_kwargs=decorator.kwargs,
- mapped_kwargs={},
- task_id=task_id,
- dag=dag,
- task_group=task_group,
- deps=cls._deps(decorator.operator_class.deps),
- )
- operator.mapped_kwargs.update(mapped_kwargs)
- for arg in mapped_kwargs.values():
- XComArg.apply_upstream_relationship(operator, arg)
- return operator
-
- @classmethod
def _deps(cls, deps: Iterable[BaseTIDep]):
if deps is BaseOperator.deps:
return cls.DEFAULT_DEPS
@@ -1749,7 +1717,7 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
@classmethod
def get_serialized_fields(cls):
if cls.__serialized_fields is None:
- fields_dict = attr.fields_dict(cls)
+ fields_dict = attr.fields_dict(MappedOperator)
cls.__serialized_fields = frozenset(
fields_dict.keys()
- {
@@ -1902,22 +1870,17 @@ class MappedOperator(Operator, LoggingMixin, DAGNode):
return ret
- def unmap(self) -> BaseOperator:
- """Get the "normal" Operator after applying the current mapping"""
+ def create_unmapped_operator(self, dag: "DAG") -> BaseOperator:
assert not isinstance(self.operator_class, str)
+ return self.operator_class(dag=dag, task_id=self.task_id, **self.partial_kwargs, **self.mapped_kwargs)
+ def unmap(self) -> BaseOperator:
+ """Get the "normal" Operator after applying the current mapping"""
dag = self.get_dag()
if not dag:
- raise RuntimeError("Cannot unmapp a task unless it has a dag")
-
- args = {
- **self.partial_kwargs,
- **self.mapped_kwargs,
- }
+ raise RuntimeError("Cannot unmap a task unless it has a DAG")
dag._remove_task(self.task_id)
- task = self.operator_class(task_id=self.task_id, dag=self.dag, **args)
-
- return task
+ return self.create_unmapped_operator(dag)
# TODO: Deprecate for Airflow 3.0
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 4996b9a..f10032d 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -1713,7 +1713,7 @@ class TaskInstance(Base, LoggingMixin):
test_mode: Optional[bool] = None,
force_fail: bool = False,
error_file: Optional[str] = None,
- session=NEW_SESSION,
+ session: Session = NEW_SESSION,
) -> None:
"""Handle Failure for the TaskInstance"""
if test_mode is None:
diff --git a/airflow/serialization/serialized_objects.py b/airflow/serialization/serialized_objects.py
index d6abda7..017f227 100644
--- a/airflow/serialization/serialized_objects.py
+++ b/airflow/serialization/serialized_objects.py
@@ -16,6 +16,7 @@
# under the License.
"""Serialized DAG and BaseOperator"""
+import contextlib
import datetime
import enum
import logging
@@ -168,7 +169,7 @@ def _decode_timetable(var: Dict[str, Any]) -> Timetable:
return timetable_class.deserialize(var[Encoding.VAR])
-class _XcomRef(NamedTuple):
+class _XComRef(NamedTuple):
"""
Used to store info needed to create XComArg when deserializing MappedOperator.
@@ -497,8 +498,8 @@ class BaseSerialization:
return {"key": arg.key, "task_id": arg.operator.task_id}
@classmethod
- def _deserialize_xcomref(cls, encoded: dict) -> _XcomRef:
- return _XcomRef(key=encoded['key'], task_id=encoded['task_id'])
+ def _deserialize_xcomref(cls, encoded: dict) -> _XComRef:
+ return _XComRef(key=encoded['key'], task_id=encoded['task_id'])
class DependencyDetector:
@@ -566,9 +567,19 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
@classmethod
def serialize_mapped_operator(cls, op: MappedOperator) -> Dict[str, Any]:
-
stock_deps = op.deps is MappedOperator.DEFAULT_DEPS
serialize_op = cls._serialize_node(op, include_deps=not stock_deps)
+
+ # Simplify op_kwargs format. It must be a dict, so we flatten it.
+ with contextlib.suppress(KeyError):
+ op_kwargs = serialize_op["mapped_kwargs"]["op_kwargs"]
+ assert op_kwargs[Encoding.TYPE] == DAT.DICT
+ serialize_op["mapped_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR]
+ with contextlib.suppress(KeyError):
+ op_kwargs = serialize_op["partial_kwargs"]["op_kwargs"]
+ assert op_kwargs[Encoding.TYPE] == DAT.DICT
+ serialize_op["partial_kwargs"]["op_kwargs"] = op_kwargs[Encoding.VAR]
+
# It must be a class at this point for it to work, not a string
assert isinstance(op.operator_class, type)
serialize_op['_task_type'] = op.operator_class.__name__
@@ -715,7 +726,13 @@ class SerializedBaseOperator(BaseOperator, BaseSerialization):
elif k == "params":
v = cls._deserialize_params_dict(v)
elif k in ("mapped_kwargs", "partial_kwargs"):
+ if "op_kwargs" not in v:
+ op_kwargs: Optional[dict] = None
+ else:
+ op_kwargs = {arg: cls._deserialize(value) for arg, value in v.pop("op_kwargs").items()}
v = {arg: cls._deserialize(value) for arg, value in v.items()}
+ if op_kwargs is not None:
+ v["op_kwargs"] = op_kwargs
elif k in cls._decorated_fields or k not in op.get_serialized_fields():
v = cls._deserialize(v)
# else use v as it is
@@ -1002,7 +1019,7 @@ class SerializedDAG(DAG, BaseSerialization):
if isinstance(task, MappedOperator):
for d in (task.mapped_kwargs, task.partial_kwargs):
for k, v in d.items():
- if not isinstance(v, _XcomRef):
+ if not isinstance(v, _XComRef):
continue
d[k] = XComArg(operator=dag.get_task(v.task_id), key=v.key)
diff --git a/tests/dags/test_mapped_taskflow.py b/tests/dags/test_mapped_taskflow.py
new file mode 100644
index 0000000..f21a9a5
--- /dev/null
+++ b/tests/dags/test_mapped_taskflow.py
@@ -0,0 +1,31 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from airflow import DAG
+from airflow.utils.dates import days_ago
+
+with DAG(dag_id='test_mapped_taskflow', start_date=days_ago(2)) as dag:
+
+ @dag.task
+ def make_list():
+ return [1, 2, {'a': 'b'}]
+
+ @dag.task
+ def consumer(value):
+ print(repr(value))
+
+ consumer.map(value=make_list())
diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py
index 0c93b49..ee94fde 100644
--- a/tests/decorators/test_python.py
+++ b/tests/decorators/test_python.py
@@ -17,7 +17,7 @@
# under the License.
import sys
from collections import namedtuple
-from datetime import date, timedelta
+from datetime import date, datetime, timedelta
from typing import Dict # noqa: F401 # This is used by annotation tests.
from typing import Tuple
@@ -490,7 +490,7 @@ def test_mapped_decorator() -> None:
assert isinstance(doubled_0, XComArg)
assert isinstance(doubled_0.operator, MappedOperator)
assert doubled_0.operator.task_id == "double"
- assert doubled_0.operator.mapped_kwargs == {"number": literal}
+ assert doubled_0.operator.mapped_kwargs == {"op_args": [], "op_kwargs": {"number": literal}}
assert doubled_1.operator.task_id == "double__1"
@@ -514,25 +514,68 @@ def test_partial_mapped_decorator() -> None:
def product(number: int, multiple: int):
return number * multiple
+ literal = [1, 2, 3]
+
with DAG('test_dag', start_date=DEFAULT_DATE) as dag:
- literal = [1, 2, 3]
- quadrupled = product.partial(task_id='times_4', multiple=3).map(number=literal)
+ quadrupled = product.partial(multiple=3).map(number=literal)
doubled = product.partial(multiple=2).map(number=literal)
trippled = product.partial(multiple=3).map(number=literal)
- product.partial(multiple=2)
+ product.partial(multiple=2) # No operator is actually created.
+
+ assert dag.task_dict == {
+ "product": quadrupled.operator,
+ "product__1": doubled.operator,
+ "product__2": trippled.operator,
+ }
assert isinstance(doubled, XComArg)
assert isinstance(doubled.operator, MappedOperator)
- assert doubled.operator.task_id == "product"
- assert doubled.operator.mapped_kwargs == {"number": literal}
- assert doubled.operator.partial_kwargs == {"task_id": "product", "multiple": 2}
+ assert doubled.operator.mapped_kwargs == {"op_args": [], "op_kwargs": {"number": literal}}
+ assert doubled.operator.partial_kwargs == {"op_args": [], "op_kwargs": {"multiple": 2}}
- assert trippled.operator.task_id == "product__1"
- assert trippled.operator.partial_kwargs == {"task_id": "product", "multiple": 3}
-
- assert quadrupled.operator.task_id == "times_4"
+ assert isinstance(trippled.operator, MappedOperator) # For type-checking on partial_kwargs.
+ assert trippled.operator.partial_kwargs == {"op_args": [], "op_kwargs": {"multiple": 3}}
assert doubled.operator is not trippled.operator
- assert [quadrupled.operator, doubled.operator, trippled.operator] == dag.tasks
+
+def test_mapped_decorator_unmap_merge_op_kwargs():
+ with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
+
+ @task_decorator
+ def task1():
+ ...
+
+ @task_decorator
+ def task2(arg1, arg2):
+ ...
+
+ task2.partial(arg1=1).map(arg2=task1())
+
+ unmapped = dag.get_task("task2").unmap()
+ assert set(unmapped.op_kwargs) == {"arg1", "arg2"}
+
+
+def test_mapped_decorator_unmap_converts_partial_kwargs():
+ with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
+
+ @task_decorator
+ def task1(arg):
+ ...
+
+ @task_decorator(retry_delay=30)
+ def task2(arg1, arg2):
+ ...
+
+ task2.partial(arg1=1).map(arg2=task1.map(arg=[1, 2]))
+
+ # Arguments to the task decorator are stored in partial_kwargs, and
+ # converted into their intended form after the task is unmapped.
+ mapped_task2 = dag.get_task("task2")
+ assert mapped_task2.partial_kwargs["retry_delay"] == 30
+ assert mapped_task2.unmap().retry_delay == timedelta(seconds=30)
+
+ mapped_task1 = dag.get_task("task1")
+ assert "retry_delay" not in mapped_task1.partial_kwargs
+ mapped_task1.unmap().retry_delay == timedelta(seconds=300) # Operator default.
diff --git a/tests/jobs/test_backfill_job.py b/tests/jobs/test_backfill_job.py
index 0878f63..40593d5 100644
--- a/tests/jobs/test_backfill_job.py
+++ b/tests/jobs/test_backfill_job.py
@@ -47,7 +47,13 @@ from airflow.utils.state import State
from airflow.utils.timeout import timeout
from airflow.utils.types import DagRunType
from tests.models import TEST_DAGS_FOLDER
-from tests.test_utils.db import clear_db_dags, clear_db_pools, clear_db_runs, set_default_pool_slots
+from tests.test_utils.db import (
+ clear_db_dags,
+ clear_db_pools,
+ clear_db_runs,
+ clear_db_xcom,
+ set_default_pool_slots,
+)
from tests.test_utils.mock_executor import MockExecutor
from tests.test_utils.timetables import cron_timetable
@@ -66,6 +72,7 @@ class TestBackfillJob:
def clean_db():
clear_db_dags()
clear_db_runs()
+ clear_db_xcom()
clear_db_pools()
@pytest.fixture(autouse=True)
@@ -1512,13 +1519,14 @@ class TestBackfillJob:
job.run()
assert executor.job_id is not None
- def test_mapped_dag(self, dag_maker):
+ @pytest.mark.parametrize("dag_id", ["test_mapped_classic", "test_mapped_taskflow"])
+ def test_mapped_dag(self, dag_id):
"""End-to-end test of a simple mapped dag"""
# Use SequentialExecutor for more predictable test behaviour
from airflow.executors.sequential_executor import SequentialExecutor
- self.dagbag.process_file(str(TEST_DAGS_FOLDER / 'test_mapped_classic.py'))
- dag = self.dagbag.get_dag('test_mapped_classic')
+ self.dagbag.process_file(str(TEST_DAGS_FOLDER / f'{dag_id}.py'))
+ dag = self.dagbag.get_dag(dag_id)
# This needs a real executor to run, so that the `make_list` task can write out the TaskMap
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 447b173..1e8d510 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -1654,6 +1654,59 @@ def test_mapped_operator_xcomarg_serde():
assert xcom_arg.operator is serialized_dag.task_dict['op1']
+def test_mapped_decorator_serde():
+ from airflow.decorators import task
+ from airflow.models.xcom_arg import XComArg
+ from airflow.serialization.serialized_objects import _XComRef
+
+ with DAG("test-dag", start_date=datetime(2020, 1, 1)) as dag:
+ op1 = BaseOperator(task_id="op1")
+ xcomarg = XComArg(op1, "my_key")
+
+ @task(retry_delay=30)
+ def x(arg1, arg2, arg3, arg4):
+ print(arg1, arg2, arg3, arg4)
+
+ x.partial("foo", arg3=[1, 2, {"a": "b"}]).map({"a": 1, "b": 2}, arg4=xcomarg)
+
+ original = dag.get_task("x")
+
+ serialized = SerializedBaseOperator._serialize(original)
+ assert serialized == {
+ '_is_dummy': False,
+ '_is_mapped': True,
+ '_task_module': 'airflow.decorators.python',
+ '_task_type': '_PythonDecoratedOperator',
+ 'downstream_task_ids': [],
+ 'partial_kwargs': {
+ 'op_args': ["foo"],
+ 'op_kwargs': {'arg3': [1, 2, {"__type": "dict", "__var": {'a': 'b'}}]},
+ 'retry_delay': 30,
+ },
+ 'mapped_kwargs': {
+ 'op_args': [{"__type": "dict", "__var": {'a': 1, 'b': 2}}],
+ 'op_kwargs': {'arg4': {'__type': 'xcomref', '__var': {'task_id': 'op1', 'key': 'my_key'}}},
+ },
+ 'task_id': 'x',
+ 'template_ext': [],
+ 'template_fields': ['op_args', 'op_kwargs'],
+ }
+
+ deserialized = SerializedBaseOperator.deserialize_operator(serialized)
+ assert isinstance(deserialized, MappedOperator)
+ assert deserialized.deps is MappedOperator.DEFAULT_DEPS
+
+ assert deserialized.mapped_kwargs == {
+ "op_args": [{"a": 1, "b": 2}],
+ "op_kwargs": {"arg4": _XComRef("op1", "my_key")},
+ }
+ assert deserialized.partial_kwargs == {
+ "retry_delay": 30,
+ "op_args": ["foo"],
+ "op_kwargs": {"arg3": [1, 2, {"a": "b"}]},
+ }
+
+
def test_mapped_task_group_serde():
execution_date = datetime(2020, 1, 1)