You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by bo...@apache.org on 2017/04/05 07:59:58 UTC

incubator-airflow git commit: [AIRFLOW-111] Include queued tasks in scheduler concurrency check

Repository: incubator-airflow
Updated Branches:
  refs/heads/master 0371df4f1 -> 3ff5abee3


[AIRFLOW-111] Include queued tasks in scheduler concurrency check

The concurrency argument in DAGs appears not to be obeyed because the
scheduler does not check the concurrency properly when checking tasks.
The tasks do not run, but this leads to a lot of scheduler churn.

Closes #2214 from saguziel/aguziel-fix-concurrency


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/3ff5abee
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/3ff5abee
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/3ff5abee

Branch: refs/heads/master
Commit: 3ff5abee3f9d29e545e021c2c060e9c9f3045236
Parents: 0371df4
Author: Alex Guziel <al...@airbnb.com>
Authored: Wed Apr 5 09:59:53 2017 +0200
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Wed Apr 5 09:59:53 2017 +0200

----------------------------------------------------------------------
 airflow/jobs.py   | 25 +++++++++++---------
 airflow/models.py | 48 ++++++++++++++++++++++----------------
 tests/jobs.py     | 62 ++++++++++++++++++++++++++++++++++++++++++++++++++
 tests/models.py   | 38 +++++++++++++++++++++++++++++++
 4 files changed, 142 insertions(+), 31 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3ff5abee/airflow/jobs.py
----------------------------------------------------------------------
diff --git a/airflow/jobs.py b/airflow/jobs.py
index b5c2d5d..0d4ae7f 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -43,7 +43,7 @@ from tabulate import tabulate
 from airflow import executors, models, settings
 from airflow import configuration as conf
 from airflow.exceptions import AirflowException
-from airflow.models import DagRun
+from airflow.models import DAG, DagRun
 from airflow.settings import Stats
 from airflow.task_runner import get_task_runner
 from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS
@@ -1036,7 +1036,7 @@ class SchedulerJob(BaseJob):
                 task_instances, key=lambda ti: (-ti.priority_weight, ti.execution_date))
 
             # DAG IDs with running tasks that equal the concurrency limit of the dag
-            dag_id_to_running_task_count = {}
+            dag_id_to_possibly_running_task_count = {}
 
             for task_instance in priority_sorted_task_instances:
                 if open_slots <= 0:
@@ -1063,22 +1063,24 @@ class SchedulerJob(BaseJob):
                 # reached.
                 dag_id = task_instance.dag_id
 
-                if dag_id not in dag_id_to_running_task_count:
-                    dag_id_to_running_task_count[dag_id] = \
-                        DagRun.get_running_tasks(
-                            session,
+                if dag_id not in dag_id_to_possibly_running_task_count:
+                    dag_id_to_possibly_running_task_count[dag_id] = \
+                        DAG.get_num_task_instances(
                             dag_id,
-                            simple_dag_bag.get_dag(dag_id).task_ids)
+                            simple_dag_bag.get_dag(dag_id).task_ids,
+                            states=[State.RUNNING, State.QUEUED],
+                            session=session)
 
-                current_task_concurrency = dag_id_to_running_task_count[dag_id]
+                current_task_concurrency = dag_id_to_possibly_running_task_count[dag_id]
                 task_concurrency_limit = simple_dag_bag.get_dag(dag_id).concurrency
-                self.logger.info("DAG {} has {}/{} running tasks"
+                self.logger.info("DAG {} has {}/{} running and queued tasks"
                                  .format(dag_id,
                                          current_task_concurrency,
                                          task_concurrency_limit))
-                if current_task_concurrency > task_concurrency_limit:
+                if current_task_concurrency >= task_concurrency_limit:
                     self.logger.info("Not executing {} since the number "
-                                     "of tasks running from DAG {} is >= to the "
+                                     "of tasks running or queued from DAG {}"
+                                     " is >= to the "
                                      "DAG's task concurrency limit of {}"
                                      .format(task_instance,
                                              dag_id,
@@ -1137,6 +1139,7 @@ class SchedulerJob(BaseJob):
                     queue=queue)
 
                 open_slots -= 1
+                dag_id_to_possibly_running_task_count[dag_id] += 1
 
     def _process_dags(self, dagbag, dags, tis_out):
         """

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3ff5abee/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index 8628100..8a91cc2 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -3505,6 +3505,34 @@ class DAG(BaseDag, LoggingMixin):
             session.merge(dag)
             session.commit()
 
+    @staticmethod
+    @provide_session
+    def get_num_task_instances(dag_id, task_ids, states=None, session=None):
+        """
+        Returns the number of task instances in the given DAG.
+
+        :param session: ORM session
+        :param dag_id: ID of the DAG to get the task concurrency of
+        :type dag_id: unicode
+        :param task_ids: A list of valid task IDs for the given DAG
+        :type task_ids: list[unicode]
+        :param states: A list of states to filter by if supplied
+        :type states: list[state]
+        :return: The number of task instances matching the given filters
+        :rtype: int
+        """
+        qry = session.query(func.count(TaskInstance.task_id)).filter(
+            TaskInstance.dag_id == dag_id,
+            TaskInstance.task_id.in_(task_ids))
+        if states is not None:
+            if None in states:
+                qry = qry.filter(or_(
+                    TaskInstance.state.in_(states),
+                    TaskInstance.state.is_(None)))
+            else:
+                qry = qry.filter(TaskInstance.state.in_(states))
+        return qry.scalar()
+
 
 class Chart(Base):
     __tablename__ = "chart"
@@ -4166,26 +4194,6 @@ class DagRun(Base):
         session.commit()
 
     @staticmethod
-    def get_running_tasks(session, dag_id, task_ids):
-        """
-        Returns the number of tasks running in the given DAG.
-
-        :param session: ORM session
-        :param dag_id: ID of the DAG to get the task concurrency of
-        :type dag_id: unicode
-        :param task_ids: A list of valid task IDs for the given DAG
-        :type task_ids: list[unicode]
-        :return: The number of running tasks
-        :rtype: int
-        """
-        qry = session.query(func.count(TaskInstance.task_id)).filter(
-            TaskInstance.dag_id == dag_id,
-            TaskInstance.task_id.in_(task_ids),
-            TaskInstance.state == State.RUNNING,
-        )
-        return qry.scalar()
-
-    @staticmethod
     def get_run(session, dag_id, execution_date):
         """
         :param dag_id: DAG ID

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3ff5abee/tests/jobs.py
----------------------------------------------------------------------
diff --git a/tests/jobs.py b/tests/jobs.py
index 3eb407b..e3caa5d 100644
--- a/tests/jobs.py
+++ b/tests/jobs.py
@@ -481,6 +481,68 @@ class SchedulerJobTest(unittest.TestCase):
         scheduler.heartrate = 0
         scheduler.run()
 
+    def test_concurrency(self):
+        dag_id = 'SchedulerJobTest.test_concurrency'
+        task_id_1 = 'dummy_task'
+        task_id_2 = 'dummy_task_nonexistent_queue'
+        # It is important that len(tasks) is less than concurrency:
+        # previously, scheduler._execute_task_instances checked the number
+        # of tasks only once, so with concurrency=3 we could execute
+        # arbitrarily many tasks in the second run.
+        dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=3)
+        task1 = DummyOperator(dag=dag, task_id=task_id_1)
+        task2 = DummyOperator(dag=dag, task_id=task_id_2)
+        dagbag = SimpleDagBag([dag])
+
+        scheduler = SchedulerJob(**self.default_scheduler_args)
+        session = settings.Session()
+
+        # create first dag run with 1 running and 1 queued
+        dr1 = scheduler.create_dag_run(dag)
+        ti1 = TI(task1, dr1.execution_date)
+        ti2 = TI(task2, dr1.execution_date)
+        ti1.refresh_from_db()
+        ti2.refresh_from_db()
+        ti1.state = State.RUNNING
+        ti2.state = State.QUEUED
+        session.merge(ti1)
+        session.merge(ti2)
+        session.commit()
+
+        self.assertEqual(State.RUNNING, dr1.state)
+        self.assertEqual(2, DAG.get_num_task_instances(dag_id, dag.task_ids,
+            states=[State.RUNNING, State.QUEUED], session=session))
+
+        # create second dag run
+        dr2 = scheduler.create_dag_run(dag)
+        ti3 = TI(task1, dr2.execution_date)
+        ti4 = TI(task2, dr2.execution_date)
+        ti3.refresh_from_db()
+        ti4.refresh_from_db()
+        # manually set to scheduled so we can pick them up
+        ti3.state = State.SCHEDULED
+        ti4.state = State.SCHEDULED
+        session.merge(ti3)
+        session.merge(ti4)
+        session.commit()
+
+        self.assertEqual(State.RUNNING, dr2.state)
+
+        scheduler._execute_task_instances(dagbag, [State.SCHEDULED])
+
+        # check that concurrency is respected
+        ti1.refresh_from_db()
+        ti2.refresh_from_db()
+        ti3.refresh_from_db()
+        ti4.refresh_from_db()
+        self.assertEqual(3, DAG.get_num_task_instances(dag_id, dag.task_ids,
+            states=[State.RUNNING, State.QUEUED], session=session))
+        self.assertEqual(State.RUNNING, ti1.state)
+        self.assertEqual(State.QUEUED, ti2.state)
+        six.assertCountEqual(self, [State.QUEUED, State.SCHEDULED], [ti3.state, ti4.state])
+
+        session.close()
+
     @provide_session
     def evaluate_dagrun(
             self,

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/3ff5abee/tests/models.py
----------------------------------------------------------------------
diff --git a/tests/models.py b/tests/models.py
index a013f8a..15450dd 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -195,6 +195,44 @@ class DagTest(unittest.TestCase):
 
         self.assertEquals(tuple(), dag.topological_sort())
 
+    def test_get_num_task_instances(self):
+        test_dag_id = 'test_get_num_task_instances_dag'
+        test_task_id = 'task_1'
+
+        test_dag = DAG(dag_id=test_dag_id, start_date=DEFAULT_DATE)
+        test_task = DummyOperator(task_id=test_task_id, dag=test_dag)
+
+        ti1 = TI(task=test_task, execution_date=DEFAULT_DATE)
+        ti1.state = None
+        ti2 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1))
+        ti2.state = State.RUNNING
+        ti3 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=2))
+        ti3.state = State.QUEUED
+        ti4 = TI(task=test_task, execution_date=DEFAULT_DATE + datetime.timedelta(days=3))
+        ti4.state = State.RUNNING
+        session = settings.Session()
+        session.merge(ti1)
+        session.merge(ti2)
+        session.merge(ti3)
+        session.merge(ti4)
+        session.commit()
+
+        self.assertEqual(0, DAG.get_num_task_instances(test_dag_id, ['fakename'],
+            session=session))
+        self.assertEqual(4, DAG.get_num_task_instances(test_dag_id, [test_task_id],
+            session=session))
+        self.assertEqual(4, DAG.get_num_task_instances(test_dag_id,
+            ['fakename', test_task_id], session=session))
+        self.assertEqual(1, DAG.get_num_task_instances(test_dag_id, [test_task_id],
+            states=[None], session=session))
+        self.assertEqual(2, DAG.get_num_task_instances(test_dag_id, [test_task_id],
+            states=[State.RUNNING], session=session))
+        self.assertEqual(3, DAG.get_num_task_instances(test_dag_id, [test_task_id],
+            states=[None, State.RUNNING], session=session))
+        self.assertEqual(4, DAG.get_num_task_instances(test_dag_id, [test_task_id],
+            states=[None, State.QUEUED, State.RUNNING], session=session))
+        session.close()
+
 
 class DagRunTest(unittest.TestCase):