You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2021/08/24 08:56:43 UTC

[airflow] branch main updated: Add pre/post execution hooks (#17576)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3d96ad6  Add pre/post execution hooks (#17576)
3d96ad6 is described below

commit 3d96ad62f91e662b62481441d8eec2994651e122
Author: Malthe Borch <mb...@gmail.com>
AuthorDate: Tue Aug 24 10:56:24 2021 +0200

    Add pre/post execution hooks (#17576)
    
    Adds overridable pre-/post-execution hooks. With this change you can override the pre-/post-execution hooks at the time of DAG creation, without needing to create your own derived operators. This means that you can - for example - skip/fail any task by raising an appropriate exception in a method passed as the pre-execution hook based on some criteria (for example, you can make a number of tasks always skipped in a development environment). You can also plug in post-execution behaviour this way [...]
---
 airflow/models/baseoperator.py                | 18 ++++++++++++++++++
 airflow/operators/subdag.py                   |  2 ++
 tests/models/test_baseoperator.py             | 24 ++++++++++++++++++++++++
 tests/serialization/test_dag_serialization.py |  2 ++
 4 files changed, 46 insertions(+)

diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index a131640..3feeb5b 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -81,6 +81,8 @@ if TYPE_CHECKING:
 ScheduleInterval = Union[str, timedelta, relativedelta]
 
 TaskStateChangeCallback = Callable[[Context], None]
+TaskPreExecuteHook = Callable[[Context], None]
+TaskPostExecuteHook = Callable[[Context, Any], None]
 
 T = TypeVar('T', bound=Callable)
 
@@ -347,6 +349,14 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
     :param on_success_callback: much like the ``on_failure_callback`` except
         that it is executed when the task succeeds.
     :type on_success_callback: TaskStateChangeCallback
+    :param pre_execute: a function to be called immediately before task
+        execution, receiving a context dictionary; raising an exception will
+        prevent the task from being executed.
+    :type pre_execute: TaskPreExecuteHook
+    :param post_execute: a function to be called immediately after task
+        execution, receiving a context dictionary and task result; raising an
+        exception will prevent the task from succeeding.
+    :type post_execute: TaskPostExecuteHook
     :param trigger_rule: defines the rule by which dependencies are applied
         for the task to get triggered. Options are:
         ``{ all_success | all_failed | all_done | one_success |
@@ -488,6 +498,8 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
         on_failure_callback: Optional[TaskStateChangeCallback] = None,
         on_success_callback: Optional[TaskStateChangeCallback] = None,
         on_retry_callback: Optional[TaskStateChangeCallback] = None,
+        pre_execute: Optional[TaskPreExecuteHook] = None,
+        post_execute: Optional[TaskPostExecuteHook] = None,
         trigger_rule: str = TriggerRule.ALL_SUCCESS,
         resources: Optional[Dict] = None,
         run_as_user: Optional[str] = None,
@@ -599,6 +611,8 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
         self.on_failure_callback = on_failure_callback
         self.on_success_callback = on_success_callback
         self.on_retry_callback = on_retry_callback
+        self._pre_execute_hook = pre_execute
+        self._post_execute_hook = post_execute
 
         if isinstance(retry_delay, timedelta):
             self.retry_delay = retry_delay
@@ -960,6 +974,8 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
     @prepare_lineage
     def pre_execute(self, context: Any):
         """This hook is triggered right before self.execute() is called."""
+        if self._pre_execute_hook is not None:
+            self._pre_execute_hook(context)
 
     def execute(self, context: Any):
         """
@@ -977,6 +993,8 @@ class BaseOperator(Operator, LoggingMixin, TaskMixin, metaclass=BaseOperatorMeta
         It is passed the execution context and any results returned by the
         operator.
         """
+        if self._post_execute_hook is not None:
+            self._post_execute_hook(context, result)
 
     def on_kill(self) -> None:
         """
diff --git a/airflow/operators/subdag.py b/airflow/operators/subdag.py
index 62b9b77..da83bee 100644
--- a/airflow/operators/subdag.py
+++ b/airflow/operators/subdag.py
@@ -156,6 +156,7 @@ class SubDagOperator(BaseSensorOperator):
             session.commit()
 
     def pre_execute(self, context):
+        super().pre_execute(context)
         execution_date = context['execution_date']
         dag_run = self._get_dagrun(execution_date)
 
@@ -184,6 +185,7 @@ class SubDagOperator(BaseSensorOperator):
         return dag_run.state != State.RUNNING
 
     def post_execute(self, context, result=None):
+        super().post_execute(context)
         execution_date = context['execution_date']
         dag_run = self._get_dagrun(execution_date=execution_date)
         self.log.info("Execution finished. State is %s", dag_run.state)
diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py
index ce848db..bf3ea4a 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -541,6 +541,30 @@ class TestBaseOperatorMethods(unittest.TestCase):
             # where the deprecated class was used
             assert warning.filename == __file__
 
+    def test_pre_execute_hook(self):
+        called = False
+
+        def hook(context):
+            nonlocal called
+            called = True
+
+        op = DummyOperator(task_id="test_task", pre_execute=hook)
+        op_copy = op.prepare_for_execution()
+        op_copy.pre_execute({})
+        assert called
+
+    def test_post_execute_hook(self):
+        called = False
+
+        def hook(context, result):
+            nonlocal called
+            called = True
+
+        op = DummyOperator(task_id="test_task", post_execute=hook)
+        op_copy = op.prepare_for_execution()
+        op_copy.post_execute({})
+        assert called
+
 
 class CustomOp(DummyOperator):
     template_fields = ("field", "field2")
diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py
index 826b10f..c387216 100644
--- a/tests/serialization/test_dag_serialization.py
+++ b/tests/serialization/test_dag_serialization.py
@@ -872,6 +872,8 @@ class TestStringifiedDAGs(unittest.TestCase):
             '_log': base_operator.log,
             '_outlets': [],
             '_upstream_task_ids': set(),
+            '_pre_execute_hook': None,
+            '_post_execute_hook': None,
             'depends_on_past': False,
             'do_xcom_push': True,
             'doc': None,