You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2021/08/13 19:49:10 UTC
[airflow] 01/08: Run mini scheduler in LocalTaskJob during task exit (#16289)
This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch v2-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
commit c009c2a05138d3b9ec68037a39d14bad1fa60115
Author: Ephraim Anierobi <sp...@gmail.com>
AuthorDate: Thu Jun 10 14:29:30 2021 +0100
Run mini scheduler in LocalTaskJob during task exit (#16289)
Currently, the chance of a task being killed by the LocalTaskJob heartbeat is high.
This is because, when the mini scheduler is enabled, it starts running right after the task is marked successful/failed in taskinstance.py.
Whenever the mini scheduling takes long enough to overlap the next job heartbeat,
the heartbeat detects that this task has succeeded with no return code, because LocalTaskJob.handle_task_exit
has not yet been called for it. Hence, the heartbeat concludes that the task was externally marked failed/successful and kills it.
This change resolves the problem by moving the mini scheduler into LocalTaskJob.handle_task_exit, ensuring
that the task will no longer be killed by the next heartbeat.
(cherry picked from commit 408bd26c22913af93d05aa70abc3c66c52cd4588)
---
airflow/jobs/local_task_job.py | 68 +++++++++++---
airflow/models/taskinstance.py | 60 +-----------
tests/cli/commands/test_task_command.py | 4 +-
tests/jobs/test_local_task_job.py | 161 +++++++++++++++++++++++---------
tests/models/test_taskinstance.py | 103 --------------------
5 files changed, 176 insertions(+), 220 deletions(-)
diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index 3afc801..e7f61f1 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -16,21 +16,24 @@
# specific language governing permissions and limitations
# under the License.
#
-
import signal
from typing import Optional
import psutil
+from sqlalchemy.exc import OperationalError
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.jobs.base_job import BaseJob
+from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
+from airflow.sentry import Sentry
from airflow.stats import Stats
from airflow.task.task_runner import get_task_runner
from airflow.utils import timezone
from airflow.utils.net import get_hostname
from airflow.utils.session import provide_session
+from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import State
@@ -159,7 +162,8 @@ class LocalTaskJob(BaseJob):
error = self.task_runner.deserialize_run_error()
self.task_instance._run_finished_callback(error=error)
if not self.task_instance.test_mode:
- self._update_dagrun_state_for_paused_dag()
+ if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True):
+ self._run_mini_scheduler_on_child_tasks()
def on_kill(self):
self.task_runner.terminate()
@@ -215,14 +219,52 @@ class LocalTaskJob(BaseJob):
self.terminating = True
@provide_session
- def _update_dagrun_state_for_paused_dag(self, session=None):
- """
- Checks for paused dags with DagRuns in the running state and
- update the DagRun state if possible
- """
- dag = self.task_instance.task.dag
- if dag.get_is_paused():
- dag_run = self.task_instance.get_dagrun(session=session)
- if dag_run:
- dag_run.dag = dag
- dag_run.update_state(session=session, execute_callbacks=True)
+ @Sentry.enrich_errors
+ def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
+ try:
+ # Re-select the row with a lock
+ dag_run = with_row_locks(
+ session.query(DagRun).filter_by(
+ dag_id=self.dag_id,
+ execution_date=self.task_instance.execution_date,
+ ),
+ session=session,
+ ).one()
+
+ # Get a partial dag with just the specific tasks we want to
+ # examine. In order for dep checks to work correctly, we
+ # include ourself (so TriggerRuleDep can check the state of the
+ # task we just executed)
+ task = self.task_instance.task
+
+ partial_dag = task.dag.partial_subset(
+ task.downstream_task_ids,
+ include_downstream=False,
+ include_upstream=False,
+ include_direct_upstream=True,
+ )
+
+ dag_run.dag = partial_dag
+ info = dag_run.task_instance_scheduling_decisions(session)
+
+ skippable_task_ids = {
+ task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
+ }
+
+ schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
+ for schedulable_ti in schedulable_tis:
+ if not hasattr(schedulable_ti, "task"):
+ schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)
+
+ num = dag_run.schedule_tis(schedulable_tis)
+ self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)
+
+ session.commit()
+ except OperationalError as e:
+ # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
+ self.log.info(
+ "Skipping mini scheduling run due to exception: %s",
+ e.statement,
+ exc_info=True,
+ )
+ session.rollback()
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index b99fa34..5fb8155 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -35,7 +35,6 @@ import lazy_object_proxy
import pendulum
from jinja2 import TemplateAssertionError, UndefinedError
from sqlalchemy import Column, Float, Index, Integer, PickleType, String, and_, func, or_
-from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import reconstructor, relationship
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.elements import BooleanClauseList
@@ -70,7 +69,7 @@ from airflow.utils.net import get_hostname
from airflow.utils.operator_helpers import context_to_airflow_vars
from airflow.utils.platform import getuser
from airflow.utils.session import provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
+from airflow.utils.sqlalchemy import UtcDateTime
from airflow.utils.state import State
from airflow.utils.timeout import timeout
@@ -1223,62 +1222,6 @@ class TaskInstance(Base, LoggingMixin):
session.commit()
- if not test_mode:
- self._run_mini_scheduler_on_child_tasks(session)
-
- @provide_session
- @Sentry.enrich_errors
- def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
- if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True):
- from airflow.models.dagrun import DagRun # Avoid circular import
-
- try:
- # Re-select the row with a lock
- dag_run = with_row_locks(
- session.query(DagRun).filter_by(
- dag_id=self.dag_id,
- execution_date=self.execution_date,
- ),
- session=session,
- ).one()
-
- # Get a partial dag with just the specific tasks we want to
- # examine. In order for dep checks to work correctly, we
- # include ourself (so TriggerRuleDep can check the state of the
- # task we just executed)
- partial_dag = self.task.dag.partial_subset(
- self.task.downstream_task_ids,
- include_downstream=False,
- include_upstream=False,
- include_direct_upstream=True,
- )
-
- dag_run.dag = partial_dag
- info = dag_run.task_instance_scheduling_decisions(session)
-
- skippable_task_ids = {
- task_id
- for task_id in partial_dag.task_ids
- if task_id not in self.task.downstream_task_ids
- }
-
- schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
- for schedulable_ti in schedulable_tis:
- if not hasattr(schedulable_ti, "task"):
- schedulable_ti.task = self.task.dag.get_task(schedulable_ti.task_id)
-
- num = dag_run.schedule_tis(schedulable_tis)
- self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)
-
- session.commit()
- except OperationalError as e:
- # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
- self.log.info(
- f"Skipping mini scheduling run due to exception: {e.statement}",
- exc_info=True,
- )
- session.rollback()
-
def _prepare_and_execute_task_with_callbacks(self, context, task):
"""Prepare Task for Execution"""
from airflow.models.renderedtifields import RenderedTaskInstanceFields
@@ -1440,6 +1383,7 @@ class TaskInstance(Base, LoggingMixin):
session=session,
)
if not res:
+ self.log.info("CHECK AND CHANGE")
return
try:
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index f50ddbc..2b93e6d 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -71,8 +71,7 @@ class TestCliTasks(unittest.TestCase):
args = self.parser.parse_args(['tasks', 'list', 'example_bash_operator', '--tree'])
task_command.task_list(args)
- @mock.patch("airflow.models.taskinstance.TaskInstance._run_mini_scheduler_on_child_tasks")
- def test_test(self, mock_run_mini_scheduler):
+ def test_test(self):
"""Test the `airflow test` command"""
args = self.parser.parse_args(
["tasks", "test", "example_python_operator", 'print_the_context', '2018-01-01']
@@ -81,7 +80,6 @@ class TestCliTasks(unittest.TestCase):
with redirect_stdout(io.StringIO()) as stdout:
task_command.task_test(args)
- mock_run_mini_scheduler.assert_not_called()
# Check that prints, and log messages, are shown
assert "'example_python_operator__print_the_context__20180101'" in stdout.getvalue()
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index ed43198..9d80647 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -33,6 +33,7 @@ from airflow import settings
from airflow.exceptions import AirflowException, AirflowFailException
from airflow.executors.sequential_executor import SequentialExecutor
from airflow.jobs.local_task_job import LocalTaskJob
+from airflow.jobs.scheduler_job import SchedulerJob
from airflow.models.dag import DAG, DagModel
from airflow.models.dagbag import DagBag
from airflow.models.taskinstance import TaskInstance
@@ -44,9 +45,8 @@ from airflow.utils.net import get_hostname
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.timeout import timeout
-from airflow.utils.types import DagRunType
from tests.test_utils.asserts import assert_queries_count
-from tests.test_utils.db import clear_db_jobs, clear_db_runs
+from tests.test_utils.config import conf_vars
from tests.test_utils.mock_executor import MockExecutor
# pylint: skip-file
@@ -57,15 +57,25 @@ TEST_DAG_FOLDER = os.environ['AIRFLOW__CORE__DAGS_FOLDER']
class TestLocalTaskJob(unittest.TestCase):
def setUp(self):
- clear_db_jobs()
- clear_db_runs()
+ db.clear_db_dags()
+ db.clear_db_jobs()
+ db.clear_db_runs()
+ db.clear_db_task_fail()
patcher = patch('airflow.jobs.base_job.sleep')
self.addCleanup(patcher.stop)
self.mock_base_job_sleep = patcher.start()
def tearDown(self) -> None:
- clear_db_jobs()
- clear_db_runs()
+ db.clear_db_dags()
+ db.clear_db_jobs()
+ db.clear_db_runs()
+ db.clear_db_task_fail()
+
+ def validate_ti_states(self, dag_run, ti_state_mapping, error_message):
+ for task_id, expected_state in ti_state_mapping.items():
+ task_instance = dag_run.get_task_instance(task_id=task_id)
+ task_instance.refresh_from_db()
+ assert task_instance.state == expected_state, error_message
def test_localtaskjob_essential_attr(self):
"""
@@ -660,57 +670,122 @@ class TestLocalTaskJob(unittest.TestCase):
if ti.state == State.RUNNING and ti.pid is not None:
break
time.sleep(0.2)
- assert ti.state == State.RUNNING
assert ti.pid is not None
+ assert ti.state == State.RUNNING
os.kill(ti.pid, signal_type)
process.join(timeout=10)
assert failure_callback_called.value == 1
assert task_terminated_externally.value == 1
assert not process.is_alive()
- def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self):
- """Test that with DAG paused, DagRun state will update when the tasks finishes the run"""
- dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE)
- op1 = PythonOperator(task_id='dummy', dag=dag, owner='airflow', python_callable=lambda: True)
+ @parameterized.expand(
+ [
+ (
+ {('scheduler', 'schedule_after_task_execution'): 'True'},
+ {'A': 'B', 'B': 'C'},
+ {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
+ {'A': State.SUCCESS, 'B': State.SCHEDULED, 'C': State.NONE},
+ {'A': State.SUCCESS, 'B': State.SUCCESS, 'C': State.SCHEDULED},
+ "A -> B -> C, with fast-follow ON when A runs, B should be QUEUED. Same for B and C.",
+ ),
+ (
+ {('scheduler', 'schedule_after_task_execution'): 'False'},
+ {'A': 'B', 'B': 'C'},
+ {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
+ {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE},
+ None,
+ "A -> B -> C, with fast-follow OFF, when A runs, B shouldn't be QUEUED.",
+ ),
+ (
+ {('scheduler', 'schedule_after_task_execution'): 'True'},
+ {'A': 'B', 'C': 'B', 'D': 'C'},
+ {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
+ {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
+ None,
+ "D -> C -> B & A -> B, when A runs but C isn't QUEUED yet, B shouldn't be QUEUED.",
+ ),
+ (
+ {('scheduler', 'schedule_after_task_execution'): 'True'},
+ {'A': 'C', 'B': 'C'},
+ {'A': State.QUEUED, 'B': State.FAILED, 'C': State.NONE},
+ {'A': State.SUCCESS, 'B': State.FAILED, 'C': State.UPSTREAM_FAILED},
+ None,
+ "A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED.",
+ ),
+ ]
+ )
+ def test_fast_follow(
+ self, conf, dependencies, init_state, first_run_state, second_run_state, error_message
+ ):
+ # pylint: disable=too-many-locals
+ with conf_vars(conf):
+ session = settings.Session()
+
+ dag = DAG('test_dagrun_fast_follow', start_date=DEFAULT_DATE)
+
+ dag_model = DagModel(
+ dag_id=dag.dag_id,
+ next_dagrun=dag.start_date,
+ is_active=True,
+ )
+ session.add(dag_model)
+ session.flush()
- session = settings.Session()
- orm_dag = DagModel(
- dag_id=dag.dag_id,
- has_task_concurrency_limits=False,
- next_dagrun=dag.start_date,
- next_dagrun_create_after=dag.following_schedule(DEFAULT_DATE),
- is_active=True,
- is_paused=True,
- )
- session.add(orm_dag)
- session.flush()
- # Write Dag to DB
- dagbag = DagBag(dag_folder="/dev/null", include_examples=False, read_dags_from_db=False)
- dagbag.bag_dag(dag, root_dag=dag)
- dagbag.sync_to_db()
+ python_callable = lambda: True
+ with dag:
+ task_a = PythonOperator(task_id='A', python_callable=python_callable)
+ task_b = PythonOperator(task_id='B', python_callable=python_callable)
+ task_c = PythonOperator(task_id='C', python_callable=python_callable)
+ if 'D' in init_state:
+ task_d = PythonOperator(task_id='D', python_callable=python_callable)
+ for upstream, downstream in dependencies.items():
+ dag.set_dependency(upstream, downstream)
- dr = dag.create_dagrun(
- run_type=DagRunType.SCHEDULED,
- state=State.RUNNING,
- execution_date=DEFAULT_DATE,
- start_date=DEFAULT_DATE,
- session=session,
- )
- assert dr.state == State.RUNNING
- ti = TaskInstance(op1, dr.execution_date)
- job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
- job1.task_runner = StandardTaskRunner(job1)
- job1.run()
- session.add(dr)
- session.refresh(dr)
- assert dr.state == State.SUCCESS
+ scheduler_job = SchedulerJob(subdir=os.devnull)
+ scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+
+ dag_run = dag.create_dagrun(run_id='test_dagrun_fast_follow', state=State.RUNNING)
+
+ task_instance_a = TaskInstance(task_a, dag_run.execution_date, init_state['A'])
+
+ task_instance_b = TaskInstance(task_b, dag_run.execution_date, init_state['B'])
+
+ task_instance_c = TaskInstance(task_c, dag_run.execution_date, init_state['C'])
+
+ if 'D' in init_state:
+ task_instance_d = TaskInstance(task_d, dag_run.execution_date, init_state['D'])
+ session.merge(task_instance_d)
+
+ session.merge(task_instance_a)
+ session.merge(task_instance_b)
+ session.merge(task_instance_c)
+ session.flush()
+
+ job1 = LocalTaskJob(
+ task_instance=task_instance_a, ignore_ti_state=True, executor=SequentialExecutor()
+ )
+ job1.task_runner = StandardTaskRunner(job1)
+
+ job2 = LocalTaskJob(
+ task_instance=task_instance_b, ignore_ti_state=True, executor=SequentialExecutor()
+ )
+ job2.task_runner = StandardTaskRunner(job2)
+
+ settings.engine.dispose()
+ job1.run()
+ self.validate_ti_states(dag_run, first_run_state, error_message)
+ if second_run_state:
+ job2.run()
+ self.validate_ti_states(dag_run, second_run_state, error_message)
+ if scheduler_job.processor_agent:
+ scheduler_job.processor_agent.end()
@pytest.fixture()
def clean_db_helper():
yield
- clear_db_jobs()
- clear_db_runs()
+ db.clear_db_jobs()
+ db.clear_db_runs()
@pytest.mark.usefixtures("clean_db_helper")
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 021809b..c1882e1 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -34,10 +34,8 @@ from sqlalchemy.orm.session import Session
from airflow import models, settings
from airflow.exceptions import AirflowException, AirflowFailException, AirflowSkipException
-from airflow.jobs.scheduler_job import SchedulerJob
from airflow.models import (
DAG,
- DagModel,
DagRun,
Pool,
RenderedTaskInstanceFields,
@@ -1963,107 +1961,6 @@ class TestTaskInstance(unittest.TestCase):
with create_session() as session:
session.query(RenderedTaskInstanceFields).delete()
- def validate_ti_states(self, dag_run, ti_state_mapping, error_message):
- for task_id, expected_state in ti_state_mapping.items():
- task_instance = dag_run.get_task_instance(task_id=task_id)
- assert task_instance.state == expected_state, error_message
-
- @parameterized.expand(
- [
- (
- {('scheduler', 'schedule_after_task_execution'): 'True'},
- {'A': 'B', 'B': 'C'},
- {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
- {'A': State.SUCCESS, 'B': State.SCHEDULED, 'C': State.NONE},
- {'A': State.SUCCESS, 'B': State.SUCCESS, 'C': State.SCHEDULED},
- "A -> B -> C, with fast-follow ON when A runs, B should be QUEUED. Same for B and C.",
- ),
- (
- {('scheduler', 'schedule_after_task_execution'): 'False'},
- {'A': 'B', 'B': 'C'},
- {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
- {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE},
- None,
- "A -> B -> C, with fast-follow OFF, when A runs, B shouldn't be QUEUED.",
- ),
- (
- {('scheduler', 'schedule_after_task_execution'): 'True'},
- {'A': 'B', 'C': 'B', 'D': 'C'},
- {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
- {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
- None,
- "D -> C -> B & A -> B, when A runs but C isn't QUEUED yet, B shouldn't be QUEUED.",
- ),
- (
- {('scheduler', 'schedule_after_task_execution'): 'True'},
- {'A': 'C', 'B': 'C'},
- {'A': State.QUEUED, 'B': State.FAILED, 'C': State.NONE},
- {'A': State.SUCCESS, 'B': State.FAILED, 'C': State.UPSTREAM_FAILED},
- None,
- "A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED.",
- ),
- ]
- )
- def test_fast_follow(
- self, conf, dependencies, init_state, first_run_state, second_run_state, error_message
- ):
- with conf_vars(conf):
- session = settings.Session()
-
- dag = DAG('test_dagrun_fast_follow', start_date=DEFAULT_DATE)
-
- dag_model = DagModel(
- dag_id=dag.dag_id,
- next_dagrun=dag.start_date,
- is_active=True,
- )
- session.add(dag_model)
- session.flush()
-
- python_callable = lambda: True
- with dag:
- task_a = PythonOperator(task_id='A', python_callable=python_callable)
- task_b = PythonOperator(task_id='B', python_callable=python_callable)
- task_c = PythonOperator(task_id='C', python_callable=python_callable)
- if 'D' in init_state:
- task_d = PythonOperator(task_id='D', python_callable=python_callable)
- for upstream, downstream in dependencies.items():
- dag.set_dependency(upstream, downstream)
-
- scheduler_job = SchedulerJob(subdir=os.devnull)
- scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
-
- dag_run = dag.create_dagrun(run_id='test_dagrun_fast_follow', state=State.RUNNING)
-
- task_instance_a = dag_run.get_task_instance(task_id=task_a.task_id)
- task_instance_a.task = task_a
- task_instance_a.set_state(init_state['A'])
-
- task_instance_b = dag_run.get_task_instance(task_id=task_b.task_id)
- task_instance_b.task = task_b
- task_instance_b.set_state(init_state['B'])
-
- task_instance_c = dag_run.get_task_instance(task_id=task_c.task_id)
- task_instance_c.task = task_c
- task_instance_c.set_state(init_state['C'])
-
- if 'D' in init_state:
- task_instance_d = dag_run.get_task_instance(task_id=task_d.task_id)
- task_instance_d.task = task_d
- task_instance_d.state = init_state['D']
-
- session.commit()
- task_instance_a.run()
-
- self.validate_ti_states(dag_run, first_run_state, error_message)
-
- if second_run_state:
- scheduler_job._critical_section_execute_task_instances(session=session)
- task_instance_b.run()
- self.validate_ti_states(dag_run, second_run_state, error_message)
- if scheduler_job.processor_agent:
- scheduler_job.processor_agent.end()
-
def test_set_state_up_for_retry(self):
dag = DAG('dag', start_date=DEFAULT_DATE)
op1 = DummyOperator(task_id='op_1', owner='test', dag=dag)