You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by mi...@apache.org on 2020/06/13 04:03:59 UTC
[airflow] branch master updated: Add task instance mutation hook (#8852)
This is an automated email from the ASF dual-hosted git repository.
milton0825 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new bacb05d Add task instance mutation hook (#8852)
bacb05d is described below
commit bacb05df38532f81a9480f3c3439c6a75e580567
Author: Chao-Han Tsai <mi...@gmail.com>
AuthorDate: Fri Jun 12 21:03:17 2020 -0700
Add task instance mutation hook (#8852)
* Add task instance mutation hook
* add merge
* update docs
* fix
* add missing import
* fix lint
* test state as well
* persist state
* fix lint
---
airflow/models/dagrun.py | 4 ++++
airflow/settings.py | 32 ++++++++++++++++++++---------
docs/concepts.rst | 50 ++++++++++++++++++++++++++++++++++++---------
tests/models/test_dagrun.py | 27 ++++++++++++++++++++++++
4 files changed, 93 insertions(+), 20 deletions(-)
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index d30d110..13a43fc 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -28,6 +28,7 @@ from sqlalchemy.orm.session import Session
from airflow.exceptions import AirflowException
from airflow.models.base import ID_LEN, Base
from airflow.models.taskinstance import TaskInstance as TI
+from airflow.settings import task_instance_mutation_hook
from airflow.stats import Stats
from airflow.ti_deps.dep_context import DepContext
from airflow.ti_deps.dependencies_states import SCHEDULEABLE_STATES
@@ -432,6 +433,7 @@ class DagRun(Base, LoggingMixin):
# check for removed or restored tasks
task_ids = []
for ti in tis:
+ task_instance_mutation_hook(ti)
task_ids.append(ti.task_id)
task = None
try:
@@ -452,6 +454,7 @@ class DagRun(Base, LoggingMixin):
"removed from DAG '{}'".format(ti, dag))
Stats.incr("task_restored_to_dag.{}".format(dag.dag_id), 1, 1)
ti.state = State.NONE
+ session.merge(ti)
# check for missing tasks
for task in dag.task_dict.values():
@@ -463,6 +466,7 @@ class DagRun(Base, LoggingMixin):
"task_instance_created-{}".format(task.__class__.__name__),
1, 1)
ti = TI(task, self.execution_date)
+ task_instance_mutation_hook(ti)
session.add(ti)
session.commit()
diff --git a/airflow/settings.py b/airflow/settings.py
index 281469d..e268a9b 100644
--- a/airflow/settings.py
+++ b/airflow/settings.py
@@ -21,7 +21,7 @@ import json
import logging
import os
import sys
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
import pendulum
from sqlalchemy import create_engine, exc
@@ -36,6 +36,11 @@ from airflow.configuration import AIRFLOW_HOME, WEBSERVER_CONFIG, conf # NOQA F
from airflow.logging_config import configure_logging
from airflow.utils.sqlalchemy import setup_event_handlers
+if TYPE_CHECKING:
+ from airflow.models.baseoperator import BaseOperator
+ from airflow.models.taskinstance import TaskInstance
+
+
log = logging.getLogger(__name__)
@@ -79,18 +84,13 @@ Session: Optional[SASession] = None
json = json
-def policy(task):
+def policy(task: 'BaseOperator'):
"""
- This policy setting allows altering tasks right before they
- are executed. It allows administrator to rewire some task parameters.
-
- Note that the ``Task`` object has a reference to the DAG
- object. So you can use the attributes of all of these to define your
- policy.
+ This policy setting allows altering tasks after they are loaded in
+ the DagBag. It allows administrator to rewire some task parameters.
To define policy, add a ``airflow_local_settings`` module
- to your PYTHONPATH that defines this ``policy`` function. It receives
- a ``Task`` object and can alter it where needed.
+ to your PYTHONPATH that defines this ``policy`` function.
Here are a few examples of how this can be useful:
@@ -103,6 +103,18 @@ def policy(task):
"""
+def task_instance_mutation_hook(task_instance: 'TaskInstance'):
+ """
+ This setting allows altering task instances before they are queued by
+ the Airflow scheduler.
+
+ To define task_instance_mutation_hook, add an ``airflow_local_settings`` module
+ to your PYTHONPATH that defines this ``task_instance_mutation_hook`` function.
+
+ This could be used, for instance, to modify the task instance during retries.
+ """
+
+
def pod_mutation_hook(pod):
"""
This setting allows altering ``kubernetes.client.models.V1Pod`` object
diff --git a/docs/concepts.rst b/docs/concepts.rst
index e016b80..79bc024 100644
--- a/docs/concepts.rst
+++ b/docs/concepts.rst
@@ -1047,10 +1047,18 @@ state.
Cluster Policy
==============
-Your local Airflow settings file can define a ``policy`` function that
-has the ability to mutate task attributes based on other task or DAG
-attributes. It receives a single argument as a reference to task objects,
-and is expected to alter its attributes.
+In case you want to apply cluster-wide mutations to the Airflow tasks,
+you can either mutate the task right after the DAG is loaded or
+mutate the task instance before task execution.
+
+Mutate tasks after DAG loaded
+-----------------------------
+
+To mutate the task right after the DAG is parsed, you can define
+a ``policy`` function in ``airflow_local_settings.py`` that mutates the
+task based on other task or DAG attributes (through ``task.dag``).
+It receives a single argument as a reference to the task object and you can alter
+its attributes.
For example, this function could apply a specific queue property when
using a specific operator, or enforce a task timeout policy, making sure
@@ -1066,13 +1074,35 @@ may look like inside your ``airflow_local_settings.py``:
if task.timeout > timedelta(hours=48):
task.timeout = timedelta(hours=48)
-To define policy, add a ``airflow_local_settings`` module to your :envvar:`PYTHONPATH`
-or to AIRFLOW_HOME/config folder that defines this ``policy`` function. It receives a ``TaskInstance``
-object and can alter it where needed.
-Please note, cluster policy currently applies to task only though you can access DAG via ``task.dag`` property.
-Also, cluster policy will have precedence over task attributes defined in DAG
-meaning if ``task.sla`` is defined in dag and also mutated via cluster policy then later will have precedence.
+Please note, cluster policy will have precedence over task
+attributes defined in the DAG, meaning that if ``task.sla`` is defined
+in the DAG and also mutated via cluster policy, then the latter will have precedence.
+
+
+Mutate task instances before task execution
+-------------------------------------------
+
+To mutate the task instance before the task execution, you can define a
+``task_instance_mutation_hook`` function in ``airflow_local_settings.py``
+that mutates the task instance.
+
+For example, this function re-routes the task to execute in a different
+queue during retries:
+
+.. code:: python
+
+ def task_instance_mutation_hook(ti):
+ if ti.try_number >= 1:
+ ti.queue = 'retry_queue'
+
+
+Where to put ``airflow_local_settings.py``?
+-------------------------------------------
+
+Add an ``airflow_local_settings.py`` file to your ``$PYTHONPATH``
+or to the ``$AIRFLOW_HOME/config`` folder.
+
Documentation & Notes
=====================
diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py
index ee223f8..98e8fa7 100644
--- a/tests/models/test_dagrun.py
+++ b/tests/models/test_dagrun.py
@@ -19,6 +19,7 @@
import datetime
import unittest
+import mock
from parameterized import parameterized
from airflow import models, settings
@@ -564,6 +565,32 @@ class TestDagRun(unittest.TestCase):
flaky_ti.refresh_from_db()
self.assertEqual(State.NONE, flaky_ti.state)
+ @parameterized.expand([(state,) for state in State.task_states])
+ @mock.patch('airflow.models.dagrun.task_instance_mutation_hook')
+ def test_task_instance_mutation_hook(self, state, mock_hook):
+ def mutate_task_instance(task_instance):
+ if task_instance.queue == 'queue1':
+ task_instance.queue = 'queue2'
+ else:
+ task_instance.queue = 'queue1'
+
+ mock_hook.side_effect = mutate_task_instance
+
+ dag = DAG('test_task_instance_mutation_hook', start_date=DEFAULT_DATE)
+ dag.add_task(DummyOperator(task_id='task_to_mutate', owner='test', queue='queue1'))
+
+ dagrun = self.create_dag_run(dag)
+ task = dagrun.get_task_instances()[0]
+ session = settings.Session()
+ task.state = state
+ session.merge(task)
+ session.commit()
+ assert task.queue == 'queue2'
+
+ dagrun.verify_integrity()
+ task = dagrun.get_task_instances()[0]
+ assert task.queue == 'queue1'
+
@parameterized.expand([
(State.SUCCESS, True),
(State.SKIPPED, True),