Posted to commits@airflow.apache.org by po...@apache.org on 2022/11/07 02:06:27 UTC

[airflow] branch main updated: Adding sensor decorator (#22562)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new cfd63df786 Adding sensor decorator (#22562)
cfd63df786 is described below

commit cfd63df786e0c40723968cb8078f808ca9d39688
Author: Mingshi <80...@users.noreply.github.com>
AuthorDate: Sun Nov 6 18:06:19 2022 -0800

    Adding sensor decorator (#22562)
    
    Co-authored-by: mingshi <mi...@coinbase.com>
---
 airflow/decorators/__init__.py                   |   3 +
 airflow/decorators/__init__.pyi                  |  37 ++++++
 airflow/decorators/sensor.py                     |  74 ++++++++++++
 airflow/example_dags/example_sensor_decorator.py |  67 +++++++++++
 airflow/sensors/python.py                        |   6 +-
 docs/apache-airflow/tutorial/taskflow.rst        |  16 +++
 tests/decorators/test_sensor.py                  | 146 +++++++++++++++++++++++
 7 files changed, 346 insertions(+), 3 deletions(-)

diff --git a/airflow/decorators/__init__.py b/airflow/decorators/__init__.py
index 2485b6645a..af478314e1 100644
--- a/airflow/decorators/__init__.py
+++ b/airflow/decorators/__init__.py
@@ -23,6 +23,7 @@ from airflow.decorators.branch_python import branch_task
 from airflow.decorators.external_python import external_python_task
 from airflow.decorators.python import python_task
 from airflow.decorators.python_virtualenv import virtualenv_task
+from airflow.decorators.sensor import sensor_task
 from airflow.decorators.short_circuit import short_circuit_task
 from airflow.decorators.task_group import task_group
 from airflow.models.dag import dag
@@ -40,6 +41,7 @@ __all__ = [
     "external_python_task",
     "branch_task",
     "short_circuit_task",
+    "sensor_task",
 ]
 
 
@@ -51,6 +53,7 @@ class TaskDecoratorCollection:
     external_python = staticmethod(external_python_task)
     branch = staticmethod(branch_task)
     short_circuit = staticmethod(short_circuit_task)
+    sensor = staticmethod(sensor_task)
 
     __call__: Any = python  # Alias '@task' to '@task.python'.
 
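The registration above exposes the decorator under two names; a minimal sketch of the
equivalence (function bodies elided, decorator arguments illustrative):

    from airflow.decorators import sensor_task, task

    # "task.sensor" is just staticmethod(sensor_task) on TaskDecoratorCollection,
    # so these two declarations are interchangeable.
    @task.sensor(poke_interval=60, timeout=600)
    def via_collection():
        ...

    @sensor_task(poke_interval=60, timeout=600)
    def via_direct_import():
        ...
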
diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi
index fad075ca4f..fd17efa174 100644
--- a/airflow/decorators/__init__.pyi
+++ b/airflow/decorators/__init__.pyi
@@ -29,6 +29,7 @@ from airflow.decorators.branch_python import branch_task
 from airflow.decorators.external_python import external_python_task
 from airflow.decorators.python import python_task
 from airflow.decorators.python_virtualenv import virtualenv_task
+from airflow.decorators.sensor import sensor_task
 from airflow.decorators.task_group import task_group
 from airflow.kubernetes.secret import Secret
 from airflow.models.dag import dag
@@ -45,6 +46,7 @@ __all__ = [
     "external_python_task",
     "branch_task",
     "short_circuit_task",
+    "sensor_task",
 ]
 
 class TaskDecoratorCollection:
@@ -410,5 +412,40 @@ class TaskDecoratorCollection:
             of the target ConfigMap's Data field will represent the key-value
             pairs as environment variables. Extends env_from.
         """
+    @overload
+    def sensor(
+        self,
+        *,
+        poke_interval: float = ...,
+        timeout: float = ...,
+        soft_fail: bool = False,
+        mode: str = ...,
+        exponential_backoff: bool = False,
+        **kwargs,
+    ) -> TaskDecorator:
+        """
+        Wraps a Python function into a sensor operator.
+
+        :param poke_interval: Time in seconds that the job should wait
+            between each try.
+        :param timeout: Time, in seconds, before the task times out and fails.
+        :param soft_fail: Set to true to mark the task as SKIPPED on failure.
+        :param mode: How the sensor operates.
+            Options are: ``{ poke | reschedule }``, default is ``poke``.
+            When set to ``poke`` the sensor takes up a worker slot for its
+            whole execution time and sleeps between pokes. Use this mode if the
+            expected runtime of the sensor is short or if a short poke interval
+            is required. Note that the sensor will hold onto a worker slot and
+            a pool slot for the duration of the sensor's runtime in this mode.
+            When set to ``reschedule`` the sensor task frees the worker slot when
+            the criteria are not yet met and it is rescheduled at a later time. Use
+            this mode if the time before the criteria are met is expected to be
+            quite long. The poke interval should be more than one minute to
+            prevent too much load on the scheduler.
+        :param exponential_backoff: allow progressively longer waits between
+            pokes by using an exponential backoff algorithm
+        """
+    @overload
+    def sensor(self, python_callable: Optional[Callable[FParams, FReturn]] = None) -> Task[FParams, FReturn]: ...
 
 task: TaskDecoratorCollection
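A short sketch of how the stubbed decorator is meant to be used, mirroring the parameters
documented above (DAG and task names are illustrative, not part of the commit):

    import pendulum

    from airflow.decorators import dag, task
    from airflow.sensors.base import PokeReturnValue

    @dag(schedule_interval=None, start_date=pendulum.datetime(2022, 1, 1, tz="UTC"), catchup=False)
    def sensor_usage_sketch():
        # "reschedule" mode frees the worker slot between pokes, which the
        # docstring above recommends when the wait is expected to be long.
        @task.sensor(poke_interval=120, timeout=3600, mode="reschedule")
        def wait_for_partition() -> PokeReturnValue:
            # is_done=False would keep the sensor poking; once True, the
            # xcom_value is pushed for downstream tasks.
            return PokeReturnValue(is_done=True, xcom_value="partition_ready")

        wait_for_partition()

    sensor_usage_sketch()
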
diff --git a/airflow/decorators/sensor.py b/airflow/decorators/sensor.py
new file mode 100644
index 0000000000..c7d1f1181b
--- /dev/null
+++ b/airflow/decorators/sensor.py
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import Callable, Sequence
+
+from airflow.decorators.base import TaskDecorator, get_unique_task_id, task_decorator_factory
+from airflow.models.taskinstance import Context
+from airflow.sensors.base import PokeReturnValue
+from airflow.sensors.python import PythonSensor
+
+
+class DecoratedSensorOperator(PythonSensor):
+    """
+    Wraps a Python callable and captures args/kwargs when called for execution.
+    :param python_callable: A reference to an object that is callable
+    :param task_id: task ID
+    :param op_args: a list of positional arguments that will get unpacked when
+        calling your callable (templated)
+    :param op_kwargs: a dictionary of keyword arguments that will get unpacked
+        in your function (templated)
+    :param kwargs_to_upstream: For certain operators, we might need to upstream certain arguments
+        that would otherwise be absorbed by the DecoratedOperator (for example python_callable for the
+        PythonOperator). This gives a user the option to upstream kwargs as needed.
+    """
+
+    template_fields: Sequence[str] = ('op_args', 'op_kwargs')
+    template_fields_renderers: dict[str, str] = {"op_args": "py", "op_kwargs": "py"}
+
+    # Since we won't mutate the arguments, we should just do the shallow copy;
+    # there are some cases we can't deepcopy the objects (e.g. protobuf).
+    shallow_copy_attrs: Sequence[str] = ('python_callable',)
+
+    def __init__(
+        self,
+        *,
+        task_id: str,
+        **kwargs,
+    ) -> None:
+        kwargs.pop('multiple_outputs')
+        kwargs['task_id'] = get_unique_task_id(task_id, kwargs.get('dag'), kwargs.get('task_group'))
+        super().__init__(**kwargs)
+
+    def poke(self, context: Context) -> PokeReturnValue:
+        return self.python_callable(*self.op_args, **self.op_kwargs)
+
+
+def sensor_task(python_callable: Callable | None = None, **kwargs) -> TaskDecorator:
+    """
+    Wraps a function into an Airflow sensor operator.
+    Accepts kwargs to pass through to the operator. Can be reused in a single DAG.
+    :param python_callable: Function to decorate
+    """
+    return task_decorator_factory(
+        python_callable=python_callable,
+        multiple_outputs=False,
+        decorated_operator_class=DecoratedSensorOperator,
+        **kwargs,
+    )
diff --git a/airflow/example_dags/example_sensor_decorator.py b/airflow/example_dags/example_sensor_decorator.py
new file mode 100644
index 0000000000..08908589e0
--- /dev/null
+++ b/airflow/example_dags/example_sensor_decorator.py
@@ -0,0 +1,67 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Example DAG demonstrating the usage of the sensor decorator."""
+
+from __future__ import annotations
+
+# [START tutorial]
+# [START import_module]
+import pendulum
+
+from airflow.decorators import dag, task
+from airflow.sensors.base import PokeReturnValue
+
+# [END import_module]
+
+
+# [START instantiate_dag]
+@dag(
+    schedule_interval=None,
+    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+    catchup=False,
+    tags=['example'],
+)
+def example_sensor_decorator():
+    # [END instantiate_dag]
+
+    # [START wait_function]
+    # Using a sensor operator to wait for the upstream data to be ready.
+    @task.sensor(poke_interval=60, timeout=3600, mode="reschedule")
+    def wait_for_upstream() -> PokeReturnValue:
+        return PokeReturnValue(is_done=True, xcom_value="xcom_value")
+
+    # [END wait_function]
+
+    # [START dummy_function]
+    @task
+    def dummy_operator() -> None:
+        pass
+
+    # [END dummy_function]
+
+    # [START main_flow]
+    wait_for_upstream() >> dummy_operator()
+    # [END main_flow]
+
+
+# [START dag_invocation]
+tutorial_sensor_dag = example_sensor_decorator()
+# [END dag_invocation]
+
+# [END tutorial]
diff --git a/airflow/sensors/python.py b/airflow/sensors/python.py
index d9241d8952..37143dbbb0 100644
--- a/airflow/sensors/python.py
+++ b/airflow/sensors/python.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 from typing import Any, Callable, Mapping, Sequence
 
-from airflow.sensors.base import BaseSensorOperator
+from airflow.sensors.base import BaseSensorOperator, PokeReturnValue
 from airflow.utils.context import Context, context_merge
 from airflow.utils.operator_helpers import determine_kwargs
 
@@ -65,10 +65,10 @@ class PythonSensor(BaseSensorOperator):
         self.op_kwargs = op_kwargs or {}
         self.templates_dict = templates_dict
 
-    def poke(self, context: Context) -> bool:
+    def poke(self, context: Context) -> PokeReturnValue:
         context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
         self.op_kwargs = determine_kwargs(self.python_callable, self.op_args, context)
 
         self.log.info("Poking callable: %s", str(self.python_callable))
         return_value = self.python_callable(*self.op_args, **self.op_kwargs)
-        return bool(return_value)
+        if isinstance(return_value, PokeReturnValue):
+            return return_value
+        return PokeReturnValue(bool(return_value))
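The net effect for existing PythonSensor users: callables that return a plain truthy or
falsy value keep working, since the boolean is wrapped in PokeReturnValue. A minimal
sketch (file path and task name are illustrative):

    import os

    from airflow.sensors.python import PythonSensor

    def _file_landed() -> bool:
        # A bare boolean still drives the poke loop exactly as before;
        # it is simply coerced via PokeReturnValue(bool(...)) now.
        return os.path.exists("/tmp/data_ready.flag")

    wait_for_file = PythonSensor(
        task_id="wait_for_file",
        python_callable=_file_landed,
        poke_interval=60,
        timeout=600,
    )
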
diff --git a/docs/apache-airflow/tutorial/taskflow.rst b/docs/apache-airflow/tutorial/taskflow.rst
index 63a3cfccd8..9db581200e 100644
--- a/docs/apache-airflow/tutorial/taskflow.rst
+++ b/docs/apache-airflow/tutorial/taskflow.rst
@@ -359,6 +359,22 @@ Notes on using the operator:
     You should upgrade to Airflow 2.4 or above in order to use it.
 
 
+Using the TaskFlow API with Sensor operators
+--------------------------------------------
+You can apply the ``@task.sensor`` decorator to convert a regular Python function into an instance of the
+BaseSensorOperator class. The Python function implements the poke logic and returns an instance of
+the ``PokeReturnValue`` class, just as the ``poke()`` method of the BaseSensorOperator does. ``PokeReturnValue``
+is a new feature in Airflow 2.3 that allows a sensor operator to push an XCom value, as described in
+the section "Having sensors return XCom values" of :doc:`apache-airflow-providers:howto/create-update-providers`.
+
+.. _taskflow/task_sensor_example:
+
+.. exampleinclude:: /../../airflow/example_dags/example_sensor_decorator.py
+    :language: python
+    :start-after: [START tutorial]
+    :end-before: [END tutorial]
+
+
 Multiple outputs inference
 --------------------------
 Tasks can also infer multiple outputs by using dict Python typing.
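One point the new tutorial section implies but does not spell out: the xcom_value carried
by PokeReturnValue is pushed under the default "return_value" XCom key (the tests below
assert exactly this), so a downstream TaskFlow task can consume it as an ordinary argument.
A hedged sketch (task names and the value are illustrative):

    from airflow.decorators import task
    from airflow.sensors.base import PokeReturnValue

    @task.sensor(poke_interval=60, timeout=3600, mode="reschedule")
    def wait_for_upstream() -> PokeReturnValue:
        return PokeReturnValue(is_done=True, xcom_value="s3://bucket/key")

    @task
    def process(path: str) -> None:
        # Receives the sensor's xcom_value through XCom.
        print(path)

    # Inside a @dag-decorated function:
    #     process(wait_for_upstream())
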
diff --git a/tests/decorators/test_sensor.py b/tests/decorators/test_sensor.py
new file mode 100644
index 0000000000..d58fb486aa
--- /dev/null
+++ b/tests/decorators/test_sensor.py
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pytest
+
+from airflow.decorators import task
+from airflow.exceptions import AirflowSensorTimeout
+from airflow.models import XCom
+from airflow.sensors.base import PokeReturnValue
+from airflow.utils.state import State
+
+
+class TestSensorDecorator:
+    def test_sensor_fails_on_none_python_callable(self, dag_maker):
+        not_callable = {}
+        with pytest.raises(TypeError):
+            task.sensor(not_callable)
+
+    def test_basic_sensor_success(self, dag_maker):
+        sensor_xcom_value = "xcom_value"
+
+        @task.sensor
+        def sensor_f():
+            return PokeReturnValue(is_done=True, xcom_value=sensor_xcom_value)
+
+        @task
+        def dummy_f():
+            pass
+
+        with dag_maker():
+            sf = sensor_f()
+            df = dummy_f()
+            sf >> df
+
+        dr = dag_maker.create_dagrun()
+        sf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date, ignore_ti_state=True)
+        tis = dr.get_task_instances()
+        assert len(tis) == 2
+        for ti in tis:
+            if ti.task_id == "sensor_f":
+                assert ti.state == State.SUCCESS
+            if ti.task_id == "dummy_f":
+                assert ti.state == State.NONE
+        actual_xcom_value = XCom.get_one(
+            key="return_value", task_id="sensor_f", dag_id=dr.dag_id, run_id=dr.run_id
+        )
+        assert actual_xcom_value == sensor_xcom_value
+
+    def test_basic_sensor_failure(self, dag_maker):
+        @task.sensor(timeout=0)
+        def sensor_f():
+            return PokeReturnValue(is_done=False, xcom_value="xcom_value")
+
+        @task
+        def dummy_f():
+            pass
+
+        with dag_maker():
+            sf = sensor_f()
+            df = dummy_f()
+            sf >> df
+
+        dr = dag_maker.create_dagrun()
+        with pytest.raises(AirflowSensorTimeout):
+            sf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date, ignore_ti_state=True)
+
+        tis = dr.get_task_instances()
+        assert len(tis) == 2
+        for ti in tis:
+            if ti.task_id == "sensor_f":
+                assert ti.state == State.FAILED
+            if ti.task_id == "dummy_f":
+                assert ti.state == State.NONE
+
+    def test_basic_sensor_soft_fail(self, dag_maker):
+        @task.sensor(timeout=0, soft_fail=True)
+        def sensor_f():
+            return PokeReturnValue(is_done=False, xcom_value="xcom_value")
+
+        @task
+        def dummy_f():
+            pass
+
+        with dag_maker():
+            sf = sensor_f()
+            df = dummy_f()
+            sf >> df
+
+        dr = dag_maker.create_dagrun()
+        sf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date, ignore_ti_state=True)
+        tis = dr.get_task_instances()
+        assert len(tis) == 2
+        for ti in tis:
+            if ti.task_id == "sensor_f":
+                assert ti.state == State.SKIPPED
+            if ti.task_id == "dummy_f":
+                assert ti.state == State.NONE
+
+    def test_basic_sensor_get_upstream_output(self, dag_maker):
+        ret_val = 100
+        sensor_xcom_value = "xcom_value"
+
+        @task
+        def upstream_f() -> int:
+            return ret_val
+
+        @task.sensor
+        def sensor_f(n: int):
+            assert n == ret_val
+            return PokeReturnValue(is_done=True, xcom_value=sensor_xcom_value)
+
+        with dag_maker():
+            uf = upstream_f()
+            sf = sensor_f(uf)
+
+        dr = dag_maker.create_dagrun()
+        uf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date, ignore_ti_state=True)
+        sf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
+        tis = dr.get_task_instances()
+        assert len(tis) == 2
+        for ti in tis:
+            if ti.task_id == "sensor_f":
+                assert ti.state == State.SUCCESS
+            if ti.task_id == "dummy_f":
+                assert ti.state == State.SUCCESS
+        actual_xcom_value = XCom.get_one(
+            key="return_value", task_id="sensor_f", dag_id=dr.dag_id, run_id=dr.run_id
+        )
+        assert actual_xcom_value == sensor_xcom_value