You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by mi...@apache.org on 2020/06/13 04:03:59 UTC

[airflow] branch master updated: Add task instance mutation hook (#8852)

This is an automated email from the ASF dual-hosted git repository.

milton0825 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new bacb05d  Add task instance mutation hook (#8852)
bacb05d is described below

commit bacb05df38532f81a9480f3c3439c6a75e580567
Author: Chao-Han Tsai <mi...@gmail.com>
AuthorDate: Fri Jun 12 21:03:17 2020 -0700

    Add task instance mutation hook (#8852)
    
    * Add task instance mutation hook
    
    * add merge
    
    * update docs
    
    * fix
    
    * add missing import
    
    * fix lint
    
    * test state as well
    
    * persist state
    
    * fix lint
---
 airflow/models/dagrun.py    |  4 ++++
 airflow/settings.py         | 32 ++++++++++++++++++++---------
 docs/concepts.rst           | 50 ++++++++++++++++++++++++++++++++++++---------
 tests/models/test_dagrun.py | 27 ++++++++++++++++++++++++
 4 files changed, 93 insertions(+), 20 deletions(-)

diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index d30d110..13a43fc 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -28,6 +28,7 @@ from sqlalchemy.orm.session import Session
 from airflow.exceptions import AirflowException
 from airflow.models.base import ID_LEN, Base
 from airflow.models.taskinstance import TaskInstance as TI
+from airflow.settings import task_instance_mutation_hook
 from airflow.stats import Stats
 from airflow.ti_deps.dep_context import DepContext
 from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES
@@ -432,6 +433,7 @@ class DagRun(Base, LoggingMixin):
         # check for removed or restored tasks
         task_ids = []
         for ti in tis:
+            task_instance_mutation_hook(ti)
             task_ids.append(ti.task_id)
             task = None
             try:
@@ -452,6 +454,7 @@ class DagRun(Base, LoggingMixin):
                               "removed from DAG '{}'".format(ti, dag))
                 Stats.incr("task_restored_to_dag.{}".format(dag.dag_id), 1, 1)
                 ti.state = State.NONE
+            session.merge(ti)
 
         # check for missing tasks
         for task in dag.task_dict.values():
@@ -463,6 +466,7 @@ class DagRun(Base, LoggingMixin):
                     "task_instance_created-{}".format(task.__class__.__name__),
                     1, 1)
                 ti = TI(task, self.execution_date)
+                task_instance_mutation_hook(ti)
                 session.add(ti)
 
         session.commit()
diff --git a/airflow/settings.py b/airflow/settings.py
index 281469d..e268a9b 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -21,7 +21,7 @@ import json
 import logging
 import os
 import sys
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 import pendulum
 from sqlalchemy import create_engine, exc
@@ -36,6 +36,11 @@ from airflow.configuration import AIRFLOW_HOME, WEBSERVER_CONFIG, conf  # NOQA F
 from airflow.logging_config import configure_logging
 from airflow.utils.sqlalchemy import setup_event_handlers
 
+if TYPE_CHECKING:
+    from airflow.models.baseoperator import BaseOperator
+    from airflow.models.taskinstance import TaskInstance
+
+
 log = logging.getLogger(__name__)
 
 
@@ -79,18 +84,13 @@ Session: Optional[SASession] = None
 json = json
 
 
-def policy(task):
+def policy(task: 'BaseOperator'):
     """
-    This policy setting allows altering tasks right before they
-    are executed. It allows administrator to rewire some task parameters.
-
-    Note that the ``Task`` object has a reference to the DAG
-    object. So you can use the attributes of all of these to define your
-    policy.
+    This policy setting allows altering tasks after they are loaded in
+    the DagBag. It allows administrator to rewire some task parameters.
 
     To define policy, add a ``airflow_local_settings`` module
-    to your PYTHONPATH that defines this ``policy`` function. It receives
-    a ``Task`` object and can alter it where needed.
+    to your PYTHONPATH that defines this ``policy`` function.
 
     Here are a few examples of how this can be useful:
 
@@ -103,6 +103,18 @@ def policy(task):
     """
 
 
+def task_instance_mutation_hook(task_instance: 'TaskInstance'):
+    """
+    This setting allows altering task instances before they are queued by
+    the Airflow scheduler.
+
+    To define task_instance_mutation_hook, add a ``airflow_local_settings`` module
+    to your PYTHONPATH that defines this ``task_instance_mutation_hook`` function.
+
+    This could be used, for instance, to modify the task instance during retries.
+    """
+
+
 def pod_mutation_hook(pod):
     """
     This setting allows altering ``kubernetes.client.models.V1Pod`` object
diff --git a/docs/concepts.rst b/docs/concepts.rst
index e016b80..79bc024 100644
--- a/docs/concepts.rst
+++ b/docs/concepts.rst
@@ -1047,10 +1047,18 @@ state.
 Cluster Policy
 ==============
 
-Your local Airflow settings file can define a ``policy`` function that
-has the ability to mutate task attributes based on other task or DAG
-attributes. It receives a single argument as a reference to task objects,
-and is expected to alter its attributes.
+In case you want to apply cluster-wide mutations to the Airflow tasks,
+you can either mutate the task right after the DAG is loaded or
+mutate the task instance before task execution.
+
+Mutate tasks after DAG loaded
+-----------------------------
+
+To mutate the task right after the DAG is parsed, you can define
+a ``policy`` function in ``airflow_local_settings.py`` that mutates the
+task based on other task or DAG attributes (through ``task.dag``).
+It receives a single argument as a reference to the task object and you can alter
+its attributes.
 
 For example, this function could apply a specific queue property when
 using a specific operator, or enforce a task timeout policy, making sure
@@ -1066,13 +1074,35 @@ may look like inside your ``airflow_local_settings.py``:
         if task.timeout > timedelta(hours=48):
             task.timeout = timedelta(hours=48)
 
-To define policy, add a ``airflow_local_settings`` module to your :envvar:`PYTHONPATH`
-or to AIRFLOW_HOME/config folder that defines this ``policy`` function. It receives a ``TaskInstance``
-object and can alter it where needed.
 
-Please note, cluster policy currently applies to task only though you can access DAG via ``task.dag`` property.
-Also, cluster policy will have precedence over task attributes defined in DAG
-meaning if ``task.sla`` is defined in dag and also mutated via cluster policy then later will have precedence.
+Please note that the cluster policy takes precedence over task
+attributes defined in the DAG: if ``task.sla`` is defined in the DAG
+and is also mutated via the cluster policy, the latter takes precedence.
+
+
+Mutate task instances before task execution
+-------------------------------------------
+
+To mutate the task instance before the task execution, you can define a
+``task_instance_mutation_hook`` function in ``airflow_local_settings.py``
+that mutates the task instance.
+
+For example, this function re-routes the task to execute in a different
+queue during retries:
+
+.. code:: python
+
+    def task_instance_mutation_hook(ti):
+        if ti.try_number >= 1:
+            ti.queue = 'retry_queue'
+
+
+Where to put ``airflow_local_settings.py``?
+-------------------------------------------
+
+Add an ``airflow_local_settings.py`` file to your ``$PYTHONPATH``
+or to the ``$AIRFLOW_HOME/config`` folder.
+
 
 Documentation & Notes
 =====================
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index ee223f8..98e8fa7 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -19,6 +19,7 @@
 import datetime
 import unittest
 
+import mock
 from parameterized import parameterized
 
 from airflow import models, settings
@@ -564,6 +565,32 @@ class TestDagRun(unittest.TestCase):
         flaky_ti.refresh_from_db()
         self.assertEqual(State.NONE, flaky_ti.state)
 
+    @parameterized.expand([(state,) for state in State.task_states])
+    @mock.patch('airflow.models.dagrun.task_instance_mutation_hook')
+    def test_task_instance_mutation_hook(self, state, mock_hook):
+        def mutate_task_instance(task_instance):
+            if task_instance.queue == 'queue1':
+                task_instance.queue = 'queue2'
+            else:
+                task_instance.queue = 'queue1'
+
+        mock_hook.side_effect = mutate_task_instance
+
+        dag = DAG('test_task_instance_mutation_hook', start_date=DEFAULT_DATE)
+        dag.add_task(DummyOperator(task_id='task_to_mutate', owner='test', queue='queue1'))
+
+        dagrun = self.create_dag_run(dag)
+        task = dagrun.get_task_instances()[0]
+        session = settings.Session()
+        task.state = state
+        session.merge(task)
+        session.commit()
+        assert task.queue == 'queue2'
+
+        dagrun.verify_integrity()
+        task = dagrun.get_task_instances()[0]
+        assert task.queue == 'queue1'
+
     @parameterized.expand([
         (State.SUCCESS, True),
         (State.SKIPPED, True),