You are viewing a plain-text version of this content; the hyperlink to the canonical version was not preserved in this copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/08/24 08:56:43 UTC
[airflow] branch main updated: Add pre/post execution hooks (#17576)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 3d96ad6 Add pre/post execution hooks (#17576)
3d96ad6 is described below
commit 3d96ad62f91e662b62481441d8eec2994651e122
Author: Malthe Borch <mb...@gmail.com>
AuthorDate: Tue Aug 24 10:56:24 2021 +0200
Add pre/post execution hooks (#17576)
Adds overridable pre-/post-execution hooks. With this change you can override pre-/post-execution hooks at DAG-creation time, without needing to create your own derived operators. This means that you can, for example, skip or fail any task by raising an appropriate exception in a method passed as the pre-execution hook, based on some criteria (for example, you can have a number of tasks always skipped in a development environment). You can also plug in post-execution behaviour this way [...]
---
airflow/models/baseoperator.py | 18 ++++++++++++++++++
airflow/operators/subdag.py | 2 ++
tests/models/test_baseoperator.py | 24 ++++++++++++++++++++++++
tests/serialization/test_dag_serialization.py | 2 ++
4 files changed, 46 insertions(+)
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index a131640..3feeb5b 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -81,6 +81,8 @@ if TYPE_CHECKING:
ScheduleInterval = Union[str, timedelta, relativedelta]
TaskStateChangeCallback = Callable[[Context], None]
+TaskPreExecuteHook = Callable[[Context], None]
+TaskPostExecuteHook = Callable[[Context, Any], None]
T = TypeVar('T', bound=Callable)
@@ -347,6 +349,14 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
:param on_success_callback: much like the ``on_failure_callback`` except
that it is executed when the task succeeds.
:type on_success_callback: TaskStateChangeCallback
+ :param pre_execute: a function to be called immediately before task
+ execution, receiving a context dictionary; raising an exception will
+ prevent the task from being executed.
+ :type pre_execute: TaskPreExecuteHook
+ :param post_execute: a function to be called immediately after task
+ execution, receiving a context dictionary and task result; raising an
+ exception will prevent the task from succeeding.
+ :type post_execute: TaskPostExecuteHook
:param trigger_rule: defines the rule by which dependencies are applied
for the task to get triggered. Options are:
``{ all_success | all_failed | all_done | one_success |
@@ -488,6 +498,8 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
on_failure_callback: Optional[TaskStateChangeCallback] = None,
on_success_callback: Optional[TaskStateChangeCallback] = None,
on_retry_callback: Optional[TaskStateChangeCallback] = None,
+ pre_execute: Optional[TaskPreExecuteHook] = None,
+ post_execute: Optional[TaskPostExecuteHook] = None,
trigger_rule: str = TriggerRule.ALL_SUCCESS,
resources: Optional[Dict] = None,
run_as_user: Optional[str] = None,
@@ -599,6 +611,8 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
self.on_failure_callback = on_failure_callback
self.on_success_callback = on_success_callback
self.on_retry_callback = on_retry_callback
+ self._pre_execute_hook = pre_execute
+ self._post_execute_hook = post_execute
if isinstance(retry_delay, timedelta):
self.retry_delay = retry_delay
@@ -960,6 +974,8 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
@prepare_lineage
def pre_execute(self, context: Any):
"""This hook is triggered right before self.execute() is called."""
+ if self._pre_execute_hook is not None:
+ self._pre_execute_hook(context)
def execute(self, context: Any):
"""
@@ -977,6 +993,8 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
It is passed the execution context and any results returned by the
operator.
"""
+ if self._post_execute_hook is not None:
+ self._post_execute_hook(context, result)
def on_kill(self) -> None:
"""
diff --git a/airflow/operators/subdag.py b/airflow/operators/subdag.py
index 62b9b77..da83bee 100644
--- a/airflow/operators/subdag.py
+++ b/airflow/operators/subdag.py
@@ -156,6 +156,7 @@ class SubDagOperator(BaseSensorOperator):
session.commit()
def pre_execute(self, context):
+ super().pre_execute(context)
execution_date = context['execution_date']
dag_run = self._get_dagrun(execution_date)
@@ -184,6 +185,7 @@ class SubDagOperator(BaseSensorOperator):
return dag_run.state != State.RUNNING
def post_execute(self, context, result=None):
+ super().post_execute(context)
execution_date = context['execution_date']
dag_run = self._get_dagrun(execution_date=execution_date)
self.log.info("Execution finished. State is %s", dag_run.state)
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index ce848db..bf3ea4a 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -541,6 +541,30 @@ class TestBaseOperatorMethods(unittest.TestCase):
# where the deprecated class was used
assert warning.filename == __file__
+ def test_pre_execute_hook(self):
+ called = False
+
+ def hook(context):
+ nonlocal called
+ called = True
+
+ op = DummyOperator(task_id="test_task", pre_execute=hook)
+ op_copy = op.prepare_for_execution()
+ op_copy.pre_execute({})
+ assert called
+
+ def test_post_execute_hook(self):
+ called = False
+
+ def hook(context, result):
+ nonlocal called
+ called = True
+
+ op = DummyOperator(task_id="test_task", post_execute=hook)
+ op_copy = op.prepare_for_execution()
+ op_copy.post_execute({})
+ assert called
+
class CustomOp(DummyOperator):
template_fields = ("field", "field2")
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 826b10f..c387216 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -872,6 +872,8 @@ class TestStringifiedDAGs(unittest.TestCase):
'_log': base_operator.log,
'_outlets': [],
'_upstream_task_ids': set(),
+ '_pre_execute_hook': None,
+ '_post_execute_hook': None,
'depends_on_past': False,
'do_xcom_push': True,
'doc': None,