You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/11/07 02:06:27 UTC
[airflow] branch main updated: Adding sensor decorator (#22562)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new cfd63df786 Adding sensor decorator (#22562)
cfd63df786 is described below
commit cfd63df786e0c40723968cb8078f808ca9d39688
Author: Mingshi <80...@users.noreply.github.com>
AuthorDate: Sun Nov 6 18:06:19 2022 -0800
Adding sensor decorator (#22562)
Co-authored-by: mingshi <mi...@coinbase.com>
---
airflow/decorators/__init__.py | 3 +
airflow/decorators/__init__.pyi | 37 ++++++
airflow/decorators/sensor.py | 74 ++++++++++++
airflow/example_dags/example_sensor_decorator.py | 67 +++++++++++
airflow/sensors/python.py | 6 +-
docs/apache-airflow/tutorial/taskflow.rst | 16 +++
tests/decorators/test_sensor.py | 146 +++++++++++++++++++++++
7 files changed, 346 insertions(+), 3 deletions(-)
diff --git a/airflow/decorators/__init__.py b/airflow/decorators/__init__.py
index 2485b6645a..af478314e1 100644
--- a/airflow/decorators/__init__.py
+++ b/airflow/decorators/__init__.py
@@ -23,6 +23,7 @@ from airflow.decorators.branch_python import branch_task
from airflow.decorators.external_python import external_python_task
from airflow.decorators.python import python_task
from airflow.decorators.python_virtualenv import virtualenv_task
+from airflow.decorators.sensor import sensor_task
from airflow.decorators.short_circuit import short_circuit_task
from airflow.decorators.task_group import task_group
from airflow.models.dag import dag
@@ -40,6 +41,7 @@ __all__ = [
"external_python_task",
"branch_task",
"short_circuit_task",
+ "sensor_task",
]
@@ -51,6 +53,7 @@ class TaskDecoratorCollection:
external_python = staticmethod(external_python_task)
branch = staticmethod(branch_task)
short_circuit = staticmethod(short_circuit_task)
+ sensor = staticmethod(sensor_task)
__call__: Any = python # Alias '@task' to '@task.python'.
diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi
index fad075ca4f..fd17efa174 100644
--- a/airflow/decorators/__init__.pyi
+++ b/airflow/decorators/__init__.pyi
@@ -29,6 +29,7 @@ from airflow.decorators.branch_python import branch_task
from airflow.decorators.external_python import external_python_task
from airflow.decorators.python import python_task
from airflow.decorators.python_virtualenv import virtualenv_task
+from airflow.decorators.sensor import sensor_task
from airflow.decorators.task_group import task_group
from airflow.kubernetes.secret import Secret
from airflow.models.dag import dag
@@ -45,6 +46,7 @@ __all__ = [
"external_python_task",
"branch_task",
"short_circuit_task",
+ "sensor_task",
]
class TaskDecoratorCollection:
@@ -410,5 +412,40 @@ class TaskDecoratorCollection:
of the target ConfigMap's Data field will represent the key-value
pairs as environment variables. Extends env_from.
"""
+    @overload
+    def sensor(
+        self,
+        *,
+        poke_interval: float = ...,
+        timeout: float = ...,
+        soft_fail: bool = False,
+        mode: str = ...,
+        exponential_backoff: bool = False,
+        **kwargs,
+    ) -> TaskDecorator:
+        """
+        Wraps a Python function into a sensor operator.
+
+        :param poke_interval: Time in seconds that the job should wait in
+            between each try
+        :param timeout: Time, in seconds, before the task times out and fails.
+        :param soft_fail: Set to ``True`` to mark the task as SKIPPED on failure
+        :param mode: How the sensor operates.
+            Options are: ``{ poke | reschedule }``, default is ``poke``.
+            When set to ``poke`` the sensor is taking up a worker slot for its
+            whole execution time and sleeps between pokes. Use this mode if the
+            expected runtime of the sensor is short or if a short poke interval
+            is required. Note that the sensor will hold onto a worker slot and
+            a pool slot for the duration of the sensor's runtime in this mode.
+            When set to ``reschedule`` the sensor task frees the worker slot when
+            the criteria is not yet met and it's rescheduled at a later time. Use
+            this mode if the time before the criteria is met is expected to be
+            quite long. The poke interval should be more than one minute to
+            prevent too much load on the scheduler.
+        :param exponential_backoff: allow progressively longer waits between
+            pokes by using exponential backoff algorithm
+        """
+    @overload
+    def sensor(self, python_callable: Callable[FParams, FReturn] | None = None) -> Task[FParams, FReturn]: ...
task: TaskDecoratorCollection
diff --git a/airflow/decorators/sensor.py b/airflow/decorators/sensor.py
new file mode 100644
index 0000000000..c7d1f1181b
--- /dev/null
+++ b/airflow/decorators/sensor.py
@@ -0,0 +1,74 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from typing import Callable, Sequence
+
+from airflow.decorators.base import TaskDecorator, get_unique_task_id, task_decorator_factory
+from airflow.models.taskinstance import Context
+from airflow.sensors.base import PokeReturnValue
+from airflow.sensors.python import PythonSensor
+
+
+class DecoratedSensorOperator(PythonSensor):
+    """
+    Wraps a Python callable and captures args/kwargs when called for execution.
+
+    :param python_callable: A reference to an object that is callable
+    :param task_id: task Id
+    :param op_args: a list of positional arguments that will get unpacked when
+        calling your callable (templated)
+    :param op_kwargs: a dictionary of keyword arguments that will get unpacked
+        in your function (templated)
+    :param kwargs: all remaining keyword arguments are forwarded to ``PythonSensor``.
+    """
+
+    template_fields: Sequence[str] = ('op_args', 'op_kwargs')
+    template_fields_renderers: dict[str, str] = {"op_args": "py", "op_kwargs": "py"}
+
+    # Shallow-copy the callable: it is never mutated, and some objects (e.g. protobuf) can't be deep-copied.
+    shallow_copy_attrs: Sequence[str] = ('python_callable',)
+
+    def __init__(
+        self,
+        *,
+        task_id: str,
+        **kwargs,
+    ) -> None:
+        # Supplied by the decorator factory and dropped here; the default avoids a KeyError if absent.
+        kwargs.pop('multiple_outputs', None)
+        kwargs['task_id'] = get_unique_task_id(task_id, kwargs.get('dag'), kwargs.get('task_group'))
+        super().__init__(**kwargs)
+
+    def poke(self, context: Context) -> PokeReturnValue:
+        # Hand the callable's result back unchanged so a PokeReturnValue keeps its xcom_value.
+        return self.python_callable(*self.op_args, **self.op_kwargs)
+
+
+def sensor_task(python_callable: Callable | None = None, **kwargs) -> TaskDecorator:
+    """
+    Wrap a function into a reusable sensor operator; extra kwargs are forwarded to it.
+
+    :param python_callable: Function to decorate
+    """
+    return task_decorator_factory(
+        python_callable=python_callable,
+        multiple_outputs=False,  # discarded again by DecoratedSensorOperator.__init__
+        decorated_operator_class=DecoratedSensorOperator,
+        **kwargs,
+    )
diff --git a/airflow/example_dags/example_sensor_decorator.py b/airflow/example_dags/example_sensor_decorator.py
new file mode 100644
index 0000000000..08908589e0
--- /dev/null
+++ b/airflow/example_dags/example_sensor_decorator.py
@@ -0,0 +1,67 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Example DAG demonstrating the usage of the sensor decorator."""
+
+from __future__ import annotations
+
+# [START tutorial]
+# [START import_module]
+import pendulum
+
+from airflow.decorators import dag, task
+from airflow.sensors.base import PokeReturnValue
+
+# [END import_module]
+
+
+# [START instantiate_dag]
+@dag(
+    schedule=None,
+    start_date=pendulum.datetime(2021, 1, 1, tz="UTC"),
+    catchup=False,
+    tags=['example'],
+)
+def example_sensor_decorator():
+    # [END instantiate_dag]
+
+    # [START wait_function]
+    # Using a sensor operator to wait for the upstream data to be ready.
+    @task.sensor(poke_interval=60, timeout=3600, mode="reschedule")
+    def wait_for_upstream() -> PokeReturnValue:
+        return PokeReturnValue(is_done=True, xcom_value="xcom_value")
+
+    # [END wait_function]
+
+    # [START dummy_function]
+    @task
+    def dummy_operator() -> None:
+        pass
+
+    # [END dummy_function]
+
+    # [START main_flow]
+    wait_for_upstream() >> dummy_operator()
+    # [END main_flow]
+
+
+# [START dag_invocation]
+tutorial_etl_dag = example_sensor_decorator()
+# [END dag_invocation]
+
+# [END tutorial]
diff --git a/airflow/sensors/python.py b/airflow/sensors/python.py
index d9241d8952..37143dbbb0 100644
--- a/airflow/sensors/python.py
+++ b/airflow/sensors/python.py
@@ -19,7 +19,7 @@ from __future__ import annotations
from typing import Any, Callable, Mapping, Sequence
-from airflow.sensors.base import BaseSensorOperator
+from airflow.sensors.base import BaseSensorOperator, PokeReturnValue
from airflow.utils.context import Context, context_merge
from airflow.utils.operator_helpers import determine_kwargs
@@ -65,10 +65,12 @@ class PythonSensor(BaseSensorOperator):
         self.op_kwargs = op_kwargs or {}
         self.templates_dict = templates_dict
-    def poke(self, context: Context) -> bool:
+    def poke(self, context: Context) -> PokeReturnValue:
         context_merge(context, self.op_kwargs, templates_dict=self.templates_dict)
         self.op_kwargs = determine_kwargs(self.python_callable, self.op_args, context)
         self.log.info("Poking callable: %s", str(self.python_callable))
         return_value = self.python_callable(*self.op_args, **self.op_kwargs)
-        return bool(return_value)
+        # Keep a callable's own PokeReturnValue (with its xcom_value); coerce anything else to bool.
+        if isinstance(return_value, PokeReturnValue):
+            return return_value
+        return PokeReturnValue(bool(return_value))
diff --git a/docs/apache-airflow/tutorial/taskflow.rst b/docs/apache-airflow/tutorial/taskflow.rst
index 63a3cfccd8..9db581200e 100644
--- a/docs/apache-airflow/tutorial/taskflow.rst
+++ b/docs/apache-airflow/tutorial/taskflow.rst
@@ -359,6 +359,22 @@ Notes on using the operator:
You should upgrade to Airflow 2.4 or above in order to use it.
+Using the TaskFlow API with Sensor operators
+--------------------------------------------
+You can apply the ``@task.sensor`` decorator to convert a regular Python function to an instance of the
+BaseSensorOperator class. The Python function implements the poke logic and returns an instance of
+the ``PokeReturnValue`` class, just as the ``poke()`` method of BaseSensorOperator does. ``PokeReturnValue`` was
+introduced in Airflow 2.3 and allows a sensor operator to push an XCom value, as described in
+section "Having sensors return XOM values" of :doc:`apache-airflow-providers:howto/create-update-providers`.
+
+.. _taskflow/task_sensor_example:
+
+.. exampleinclude:: /../../airflow/example_dags/example_sensor_decorator.py
+ :language: python
+ :start-after: [START tutorial]
+ :end-before: [END tutorial]
+
+
Multiple outputs inference
--------------------------
Tasks can also infer multiple outputs by using dict Python typing.
diff --git a/tests/decorators/test_sensor.py b/tests/decorators/test_sensor.py
new file mode 100644
index 0000000000..d58fb486aa
--- /dev/null
+++ b/tests/decorators/test_sensor.py
@@ -0,0 +1,146 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import pytest
+
+from airflow.decorators import task
+from airflow.exceptions import AirflowSensorTimeout
+from airflow.models import XCom
+from airflow.sensors.base import PokeReturnValue
+from airflow.utils.state import State
+
+
+class TestSensorDecorator:
+    def test_sensor_fails_on_non_callable_python_callable(self):
+        not_callable = {}  # a dict is neither None nor callable
+        with pytest.raises(TypeError):
+            task.sensor(not_callable)
+
+    def test_basic_sensor_success(self, dag_maker):
+        sensor_xcom_value = "xcom_value"
+
+        @task.sensor
+        def sensor_f():
+            return PokeReturnValue(is_done=True, xcom_value=sensor_xcom_value)
+
+        @task
+        def dummy_f():
+            pass
+
+        with dag_maker():
+            sf = sensor_f()
+            df = dummy_f()
+            sf >> df
+
+        dr = dag_maker.create_dagrun()
+        sf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date, ignore_ti_state=True)
+        tis = dr.get_task_instances()
+        assert len(tis) == 2
+        for ti in tis:
+            if ti.task_id == "sensor_f":
+                assert ti.state == State.SUCCESS
+            if ti.task_id == "dummy_f":
+                assert ti.state == State.NONE
+        actual_xcom_value = XCom.get_one(
+            key="return_value", task_id="sensor_f", dag_id=dr.dag_id, run_id=dr.run_id
+        )
+        assert actual_xcom_value == sensor_xcom_value
+
+    def test_basic_sensor_failure(self, dag_maker):
+        @task.sensor(timeout=0)  # times out after the first unsuccessful poke
+        def sensor_f():
+            return PokeReturnValue(is_done=False, xcom_value="xcom_value")
+
+        @task
+        def dummy_f():
+            pass
+
+        with dag_maker():
+            sf = sensor_f()
+            df = dummy_f()
+            sf >> df
+
+        dr = dag_maker.create_dagrun()
+        with pytest.raises(AirflowSensorTimeout):
+            sf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date, ignore_ti_state=True)
+
+        tis = dr.get_task_instances()
+        assert len(tis) == 2
+        for ti in tis:
+            if ti.task_id == "sensor_f":
+                assert ti.state == State.FAILED
+            if ti.task_id == "dummy_f":
+                assert ti.state == State.NONE
+
+    def test_basic_sensor_soft_fail(self, dag_maker):
+        @task.sensor(timeout=0, soft_fail=True)  # times out immediately, but skips instead of failing
+        def sensor_f():
+            return PokeReturnValue(is_done=False, xcom_value="xcom_value")
+
+        @task
+        def dummy_f():
+            pass
+
+        with dag_maker():
+            sf = sensor_f()
+            df = dummy_f()
+            sf >> df
+
+        dr = dag_maker.create_dagrun()
+        sf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date, ignore_ti_state=True)
+        tis = dr.get_task_instances()
+        assert len(tis) == 2
+        for ti in tis:
+            if ti.task_id == "sensor_f":
+                assert ti.state == State.SKIPPED
+            if ti.task_id == "dummy_f":
+                assert ti.state == State.NONE
+
+    def test_basic_sensor_get_upstream_output(self, dag_maker):
+        ret_val = 100
+        sensor_xcom_value = "xcom_value"
+
+        @task
+        def upstream_f() -> int:
+            return ret_val
+
+        @task.sensor
+        def sensor_f(n: int):
+            assert n == ret_val  # the upstream XCom is resolved into the argument
+            return PokeReturnValue(is_done=True, xcom_value=sensor_xcom_value)
+
+        with dag_maker():
+            uf = upstream_f()
+            sf = sensor_f(uf)
+
+        dr = dag_maker.create_dagrun()
+        uf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date, ignore_ti_state=True)
+        sf.operator.run(start_date=dr.execution_date, end_date=dr.execution_date)
+        tis = dr.get_task_instances()
+        assert len(tis) == 2
+        for ti in tis:
+            if ti.task_id == "sensor_f":
+                assert ti.state == State.SUCCESS
+            if ti.task_id == "upstream_f":  # the upstream task ran first and must have succeeded
+                assert ti.state == State.SUCCESS
+        actual_xcom_value = XCom.get_one(
+            key="return_value", task_id="sensor_f", dag_id=dr.dag_id, run_id=dr.run_id
+        )
+        assert actual_xcom_value == sensor_xcom_value