You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by je...@apache.org on 2022/09/07 23:17:43 UTC

[airflow] branch main updated: Add ``@task.short_circuit`` TaskFlow decorator (#25752)

This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ebef9ed3fa Add ``@task.short_circuit`` TaskFlow decorator (#25752)
ebef9ed3fa is described below

commit ebef9ed3fa4a9a1e69b4405945e7cd939f499ee5
Author: Josh Fell <48...@users.noreply.github.com>
AuthorDate: Wed Sep 7 19:17:34 2022 -0400

    Add ``@task.short_circuit`` TaskFlow decorator (#25752)
---
 airflow/decorators/__init__.py                     |  3 +
 airflow/decorators/__init__.pyi                    | 20 ++++++
 airflow/decorators/short_circuit.py                | 83 ++++++++++++++++++++++
 ...rator.py => example_short_circuit_decorator.py} | 41 +++++------
 .../example_dags/example_short_circuit_operator.py |  4 --
 docs/apache-airflow/howto/operator/python.rst      | 33 +++++----
 tests/decorators/test_short_circuit.py             | 72 +++++++++++++++++++
 7 files changed, 215 insertions(+), 41 deletions(-)

diff --git a/airflow/decorators/__init__.py b/airflow/decorators/__init__.py
index 6004a397e4..ad5d6431e5 100644
--- a/airflow/decorators/__init__.py
+++ b/airflow/decorators/__init__.py
@@ -22,6 +22,7 @@ from airflow.decorators.branch_python import branch_task
 from airflow.decorators.external_python import external_python_task
 from airflow.decorators.python import python_task
 from airflow.decorators.python_virtualenv import virtualenv_task
+from airflow.decorators.short_circuit import short_circuit_task
 from airflow.decorators.task_group import task_group
 from airflow.models.dag import dag
 from airflow.providers_manager import ProvidersManager
@@ -37,6 +38,7 @@ __all__ = [
     "virtualenv_task",
     "external_python_task",
     "branch_task",
+    "short_circuit_task",
 ]
 
 
@@ -47,6 +49,7 @@ class TaskDecoratorCollection:
     virtualenv = staticmethod(virtualenv_task)
     external_python = staticmethod(external_python_task)
     branch = staticmethod(branch_task)
+    short_circuit = staticmethod(short_circuit_task)
 
     __call__: Any = python  # Alias '@task' to '@task.python'.
 
diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi
index b5992cf513..e684860f4a 100644
--- a/airflow/decorators/__init__.pyi
+++ b/airflow/decorators/__init__.pyi
@@ -44,6 +44,7 @@ __all__ = [
     "virtualenv_task",
     "external_python_task",
     "branch_task",
+    "short_circuit_task",
 ]
 
 class TaskDecoratorCollection:
@@ -171,6 +172,25 @@ class TaskDecoratorCollection:
         """
     @overload
     def branch(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
+    @overload
+    def short_circuit(
+        self,
+        *,
+        multiple_outputs: Optional[bool] = None,
+        ignore_downstream_trigger_rules: bool = True,
+        **kwargs,
+    ) -> TaskDecorator:
+        """Create a decorator to wrap the decorated callable into a ShortCircuitOperator.
+
+        :param multiple_outputs: If set, function return value will be unrolled to multiple XCom values.
+            Dict will unroll to XCom values with keys as XCom keys. Defaults to False.
+        :param ignore_downstream_trigger_rules: If set to True and the callable returns a falsy value, all
+            downstream tasks of this task are skipped. If set to False, only the direct downstream task(s)
+            are skipped and the ``trigger_rule`` defined for any other downstream tasks is respected.
+            Defaults to True.
+        """
+    @overload
+    def short_circuit(self, python_callable: Callable[FParams, FReturn]) -> Task[FParams, FReturn]: ...
     # [START decorator_signature]
     def docker(
         self,
diff --git a/airflow/decorators/short_circuit.py b/airflow/decorators/short_circuit.py
new file mode 100644
index 0000000000..f3aec185b7
--- /dev/null
+++ b/airflow/decorators/short_circuit.py
@@ -0,0 +1,83 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Callable, Optional, Sequence
+
+from airflow.decorators.base import DecoratedOperator, TaskDecorator, task_decorator_factory
+from airflow.operators.python import ShortCircuitOperator
+
+
+class _ShortCircuitDecoratedOperator(DecoratedOperator, ShortCircuitOperator):
+    """
+    Wraps a Python callable in a ShortCircuitOperator and captures args/kwargs when called.
+
+    :param python_callable: A reference to an object that is callable
+    :param op_kwargs: a dictionary of keyword arguments that will get unpacked
+        in your function (templated)
+    :param op_args: a list of positional arguments that will get unpacked when
+        calling your callable (templated)
+    :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to
+        multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.
+    """
+
+    template_fields: Sequence[str] = ('op_args', 'op_kwargs')
+    template_fields_renderers = {"op_args": "py", "op_kwargs": "py"}
+
+    # Since we won't mutate the arguments, a shallow copy is sufficient; there
+    # are some cases where we can't deepcopy the objects (e.g. protobuf).
+    shallow_copy_attrs: Sequence[str] = ('python_callable',)
+
+    custom_operator_name: str = '@task.short_circuit'
+
+    def __init__(self, *, python_callable, op_args, op_kwargs, **kwargs) -> None:
+        kwargs_to_upstream = {  # presumably handed to ShortCircuitOperator by DecoratedOperator; verify
+            "python_callable": python_callable,
+            "op_args": op_args,
+            "op_kwargs": op_kwargs,
+        }
+        super().__init__(
+            kwargs_to_upstream=kwargs_to_upstream,
+            python_callable=python_callable,
+            op_args=op_args,
+            op_kwargs=op_kwargs,
+            **kwargs,
+        )
+
+
+def short_circuit_task(
+    python_callable: Optional[Callable] = None,
+    multiple_outputs: Optional[bool] = None,
+    **kwargs,
+) -> TaskDecorator:
+    """Wraps a function into a ShortCircuitOperator.
+
+    Accepts operator kwargs (e.g. ``ignore_downstream_trigger_rules``). Can be reused in a single DAG.
+
+    Exposed to DAG authors as ``@task.short_circuit`` via ``TaskDecoratorCollection.short_circuit``.
+
+    :param python_callable: Function to decorate
+    :param multiple_outputs: If set to True, the decorated function's return value will be unrolled to
+        multiple XCom values. Dict will unroll to XCom values with its keys as XCom keys. Defaults to False.
+
+    :meta private:
+    """
+    return task_decorator_factory(
+        python_callable=python_callable,
+        multiple_outputs=multiple_outputs,
+        decorated_operator_class=_ShortCircuitDecoratedOperator,
+        **kwargs,
+    )
diff --git a/airflow/example_dags/example_short_circuit_operator.py b/airflow/example_dags/example_short_circuit_decorator.py
similarity index 66%
copy from airflow/example_dags/example_short_circuit_operator.py
copy to airflow/example_dags/example_short_circuit_decorator.py
index 2278de30e6..4e7e098624 100644
--- a/airflow/example_dags/example_short_circuit_operator.py
+++ b/airflow/example_dags/example_short_circuit_decorator.py
@@ -1,4 +1,3 @@
-#
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -16,37 +15,30 @@
 # specific language governing permissions and limitations
 # under the License.
 
-"""Example DAG demonstrating the usage of the ShortCircuitOperator."""
+"""Example DAG demonstrating the usage of the `@task.short_circuit()` TaskFlow decorator."""
 import pendulum
 
-from airflow import DAG
+from airflow.decorators import dag, task
 from airflow.models.baseoperator import chain
 from airflow.operators.empty import EmptyOperator
-from airflow.operators.python import ShortCircuitOperator
 from airflow.utils.trigger_rule import TriggerRule
 
-with DAG(
-    dag_id='example_short_circuit_operator',
-    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
-    catchup=False,
-    tags=['example'],
-) as dag:
-    # [START howto_operator_short_circuit]
-    cond_true = ShortCircuitOperator(
-        task_id='condition_is_True',
-        python_callable=lambda: True,
-    )
 
-    cond_false = ShortCircuitOperator(
-        task_id='condition_is_False',
-        python_callable=lambda: False,
-    )
+@dag(start_date=pendulum.datetime(2021, 1, 1, tz="UTC"), catchup=False, tags=['example'])
+def example_short_circuit_decorator():
+    # [START howto_operator_short_circuit]
+    @task.short_circuit()
+    def check_condition(condition):
+        return condition
 
     ds_true = [EmptyOperator(task_id='true_' + str(i)) for i in [1, 2]]
     ds_false = [EmptyOperator(task_id='false_' + str(i)) for i in [1, 2]]
 
-    chain(cond_true, *ds_true)
-    chain(cond_false, *ds_false)
+    condition_is_true = check_condition.override(task_id="condition_is_true")(condition=True)
+    condition_is_false = check_condition.override(task_id="condition_is_false")(condition=False)
+
+    chain(condition_is_true, *ds_true)
+    chain(condition_is_false, *ds_false)
     # [END howto_operator_short_circuit]
 
     # [START howto_operator_short_circuit_trigger_rules]
@@ -56,9 +48,12 @@ with DAG(
 
     task_7 = EmptyOperator(task_id="task_7", trigger_rule=TriggerRule.ALL_DONE)
 
-    short_circuit = ShortCircuitOperator(
-        task_id="short_circuit", ignore_downstream_trigger_rules=False, python_callable=lambda: False
+    short_circuit = check_condition.override(task_id="short_circuit", ignore_downstream_trigger_rules=False)(
+        condition=False
     )
 
     chain(task_1, [task_2, short_circuit], [task_3, task_4], [task_5, task_6], task_7)
     # [END howto_operator_short_circuit_trigger_rules]
+
+
+example_dag = example_short_circuit_decorator()
diff --git a/airflow/example_dags/example_short_circuit_operator.py b/airflow/example_dags/example_short_circuit_operator.py
index 2278de30e6..3fc9f1bd00 100644
--- a/airflow/example_dags/example_short_circuit_operator.py
+++ b/airflow/example_dags/example_short_circuit_operator.py
@@ -31,7 +31,6 @@ with DAG(
     catchup=False,
     tags=['example'],
 ) as dag:
-    # [START howto_operator_short_circuit]
     cond_true = ShortCircuitOperator(
         task_id='condition_is_True',
         python_callable=lambda: True,
@@ -47,9 +46,7 @@ with DAG(
 
     chain(cond_true, *ds_true)
     chain(cond_false, *ds_false)
-    # [END howto_operator_short_circuit]
 
-    # [START howto_operator_short_circuit_trigger_rules]
     [task_1, task_2, task_3, task_4, task_5, task_6] = [
         EmptyOperator(task_id=f"task_{i}") for i in range(1, 7)
     ]
@@ -61,4 +58,3 @@ with DAG(
     )
 
     chain(task_1, [task_2, short_circuit], [task_3, task_4], [task_5, task_6], task_7)
-    # [END howto_operator_short_circuit_trigger_rules]
diff --git a/docs/apache-airflow/howto/operator/python.rst b/docs/apache-airflow/howto/operator/python.rst
index b61ea77df1..7128a2a5e0 100644
--- a/docs/apache-airflow/howto/operator/python.rst
+++ b/docs/apache-airflow/howto/operator/python.rst
@@ -129,19 +129,24 @@ If you want the context related to datetime objects like ``data_interval_start``
 .. _howto/operator:ShortCircuitOperator:
 
 ShortCircuitOperator
-========================
+====================
+
+Use the ``@task.short_circuit`` decorator to control whether a pipeline continues
+if a condition is satisfied or a truthy value is obtained.
+
+.. warning::
+    The ``@task.short_circuit`` decorator is recommended over the classic :class:`~airflow.operators.python.ShortCircuitOperator`
+    to short-circuit pipelines via Python callables.
 
-Use the :class:`~airflow.operators.python.ShortCircuitOperator` to control whether a pipeline continues
-if a condition is satisfied or a truthy value is obtained. The evaluation of this condition and truthy value
-is done via the output of a ``python_callable``. If the ``python_callable`` returns True or a truthy value,
+Whether the condition is satisfied
+is determined by the output of the decorated function. If the decorated function returns True or a truthy value,
 the pipeline is allowed to continue and an :ref:`XCom <concepts:xcom>` of the output will be pushed. If the
 output is False or a falsy value, the pipeline will be short-circuited based on the configured
-short-circuiting (more on this later). In the example below, the tasks that follow the "condition_is_True"
-ShortCircuitOperator will execute while the tasks downstream of the "condition_is_False" ShortCircuitOperator
-will be skipped.
+short-circuiting (more on this later). In the example below, the tasks that follow the "condition_is_true"
+task will execute while the tasks downstream of the "condition_is_false" task will be skipped.
 
 
-.. exampleinclude:: /../../airflow/example_dags/example_short_circuit_operator.py
+.. exampleinclude:: /../../airflow/example_dags/example_short_circuit_decorator.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_short_circuit]
@@ -155,14 +160,14 @@ set to False, the direct downstream tasks are skipped but the specified ``trigge
 downstream tasks are respected. In this short-circuiting configuration, the operator assumes the direct
 downstream task(s) were purposely meant to be skipped but perhaps not other subsequent tasks. This
 configuration is especially useful if only *part* of a pipeline should be short-circuited rather than all
-tasks which follow the ShortCircuitOperator task.
+tasks which follow the short-circuiting task.
 
-In the example below, notice that the ShortCircuitOperator task is configured to respect downstream trigger
-rules. This means while the tasks that follow the "short_circuit" ShortCircuitOperator task will be skipped
-since the ``python_callable`` returns False, "task_7" will still execute as its set to execute when upstream
+In the example below, notice that the "short_circuit" task is configured to respect downstream trigger
+rules. This means that while the tasks that follow the "short_circuit" task will be skipped
+since the decorated function returns False, "task_7" will still execute because it is set to run when upstream
 tasks have completed running regardless of status (i.e. the ``TriggerRule.ALL_DONE`` trigger rule).
 
-.. exampleinclude:: /../../airflow/example_dags/example_short_circuit_operator.py
+.. exampleinclude:: /../../airflow/example_dags/example_short_circuit_decorator.py
     :language: python
     :dedent: 4
     :start-after: [START howto_operator_short_circuit_trigger_rules]
@@ -173,7 +178,7 @@ tasks have completed running regardless of status (i.e. the ``TriggerRule.ALL_DO
 Passing in arguments
 ^^^^^^^^^^^^^^^^^^^^
 
-Both the ``op_args`` and ``op_kwargs`` arguments can be used in same way as described for the PythonOperator.
+Pass extra arguments to the ``@task.short_circuit``-decorated function as you would with a normal Python function.
 
 
 Templating
diff --git a/tests/decorators/test_short_circuit.py b/tests/decorators/test_short_circuit.py
new file mode 100644
index 0000000000..c79da558de
--- /dev/null
+++ b/tests/decorators/test_short_circuit.py
@@ -0,0 +1,72 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from pendulum import datetime
+
+from airflow.decorators import task
+from airflow.utils.state import State
+from airflow.utils.trigger_rule import TriggerRule
+
+DEFAULT_DATE = datetime(2022, 8, 17)
+
+
+def test_short_circuit_decorator(dag_maker):
+    with dag_maker():
+
+        @task
+        def empty():
+            ...
+
+        @task.short_circuit()
+        def short_circuit(condition):  # truthiness of the return value drives short-circuiting
+            return condition
+
+        short_circuit_false = short_circuit.override(task_id="short_circuit_false")(condition=False)
+        task_1 = empty.override(task_id="task_1")()
+        short_circuit_false >> task_1
+
+        short_circuit_true = short_circuit.override(task_id="short_circuit_true")(condition=True)
+        task_2 = empty.override(task_id="task_2")()
+        short_circuit_true >> task_2
+
+        short_circuit_respect_trigger_rules = short_circuit.override(
+            task_id="short_circuit_respect_trigger_rules", ignore_downstream_trigger_rules=False
+        )(condition=False)
+        task_3 = empty.override(task_id="task_3")()
+        task_4 = empty.override(task_id="task_4")()
+        task_5 = empty.override(task_id="task_5", trigger_rule=TriggerRule.ALL_DONE)()
+        short_circuit_respect_trigger_rules >> [task_3, task_4] >> task_5
+
+    dr = dag_maker.create_dagrun()
+
+    for t in dag_maker.dag.tasks:
+        t.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)
+
+    task_state_mapping = {  # expected final state for every task_id
+        "short_circuit_false": State.SUCCESS,
+        "task_1": State.SKIPPED,  # downstream of a falsy short-circuit
+        "short_circuit_true": State.SUCCESS,
+        "task_2": State.SUCCESS,
+        "short_circuit_respect_trigger_rules": State.SUCCESS,  # the short-circuit task itself succeeds
+        "task_3": State.SKIPPED,  # direct downstream tasks are still skipped
+        "task_4": State.SKIPPED,
+        "task_5": State.SUCCESS,  # ALL_DONE respected since ignore_downstream_trigger_rules=False
+    }
+
+    tis = dr.get_task_instances()  # NOTE(review): a task_id missing here would go unchecked
+    for ti in tis:
+        assert ti.state == task_state_mapping[ti.task_id]