You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by sa...@apache.org on 2017/10/05 21:37:31 UTC
incubator-airflow git commit: [AIRFLOW-1634] Adds task_concurrency
feature
Repository: incubator-airflow
Updated Branches:
refs/heads/master 96206b0e5 -> cfc2f73c4
[AIRFLOW-1634] Adds task_concurrency feature
This adds a feature to limit the concurrency of
individual tasks across execution dates. By default
(task_concurrency unset), existing behavior is unchanged.
Closes #2624 from saguziel/aguziel-task-
concurrency
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/cfc2f73c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/cfc2f73c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/cfc2f73c
Branch: refs/heads/master
Commit: cfc2f73c445074e1e09d6ef6a056cd2b33a945da
Parents: 96206b0
Author: Alex Guziel <al...@airbnb.com>
Authored: Thu Oct 5 14:37:26 2017 -0700
Committer: Alex Guziel <al...@airbnb.com>
Committed: Thu Oct 5 14:37:26 2017 -0700
----------------------------------------------------------------------
airflow/jobs.py | 51 +++++++--
airflow/models.py | 21 +++-
airflow/ti_deps/dep_context.py | 4 +-
airflow/ti_deps/deps/task_concurrency_dep.py | 37 +++++++
airflow/utils/dag_processing.py | 29 ++++-
tests/jobs.py | 127 +++++++++++++++++++---
tests/models.py | 56 ++++++++++
tests/ti_deps/deps/test_task_concurrency.py | 51 +++++++++
8 files changed, 348 insertions(+), 28 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/airflow/jobs.py
----------------------------------------------------------------------
diff --git a/airflow/jobs.py b/airflow/jobs.py
index 8ca81dc..2675bd3 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -1024,6 +1024,30 @@ class SchedulerJob(BaseJob):
)
@provide_session
+ def __get_task_concurrency_map(self, states, session=None):
+ """
+ Returns a map from tasks to number in the states list given.
+
+ :param states: List of states to query for
+ :type states: List[State]
+ :return: A map from (dag_id, task_id) to count of tasks in states
+ :rtype: Dict[[String, String], Int]
+
+ """
+ TI = models.TaskInstance
+ ti_concurrency_query = (
+ session
+ .query(TI.task_id, TI.dag_id, func.count('*'))
+ .filter(TI.state.in_(states))
+ .group_by(TI.task_id, TI.dag_id)
+ ).all()
+ task_map = defaultdict(int)
+ for result in ti_concurrency_query:
+ task_id, dag_id, count = result
+ task_map[(dag_id, task_id)] = count
+ return task_map
+
+ @provide_session
def _find_executable_task_instances(self, simple_dag_bag, states, session=None):
"""
Finds TIs that are ready for execution with respect to pool limits,
@@ -1038,6 +1062,9 @@ class SchedulerJob(BaseJob):
:type states: Tuple[State]
:return: List[TaskInstance]
"""
+ # TODO(saguziel): Change this to include QUEUED, for concurrency
+ # purposes we may want to count queued tasks
+ states_to_count_as_running = [State.RUNNING]
executable_tis = []
# Get all the queued task instances from associated with scheduled
@@ -1082,6 +1109,8 @@ class SchedulerJob(BaseJob):
for task_instance in task_instances_to_examine:
pool_to_task_instances[task_instance.pool].append(task_instance)
+ task_concurrency_map = self.__get_task_concurrency_map(states=states_to_count_as_running, session=session)
+
# Go through each pool, and queue up a task for execution if there are
# any open slots in the pool.
for pool, task_instances in pool_to_task_instances.items():
@@ -1119,6 +1148,7 @@ class SchedulerJob(BaseJob):
# Check to make sure that the task concurrency of the DAG hasn't been
# reached.
dag_id = task_instance.dag_id
+ simple_dag = simple_dag_bag.get_dag(dag_id)
if dag_id not in dag_id_to_possibly_running_task_count:
# TODO(saguziel): also check against QUEUED state, see AIRFLOW-1104
@@ -1126,7 +1156,7 @@ class SchedulerJob(BaseJob):
DAG.get_num_task_instances(
dag_id,
simple_dag_bag.get_dag(dag_id).task_ids,
- states=[State.RUNNING],
+ states=states_to_count_as_running,
session=session)
current_task_concurrency = dag_id_to_possibly_running_task_count[dag_id]
@@ -1143,6 +1173,16 @@ class SchedulerJob(BaseJob):
)
continue
+ task_concurrency = simple_dag.get_task_special_arg(task_instance.task_id, 'task_concurrency')
+ if task_concurrency is not None:
+ num_running = task_concurrency_map[((task_instance.dag_id, task_instance.task_id))]
+ if num_running >= task_concurrency:
+ self.logger.info("Not executing %s since the task concurrency for this task"
+ " has been reached.", task_instance)
+ continue
+ else:
+ task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1
+
if self.executor.has_task(task_instance):
self.log.debug(
"Not handling task %s as the executor reports it is running",
@@ -1723,16 +1763,9 @@ class SchedulerJob(BaseJob):
if pickle_dags:
pickle_id = dag.pickle(session).id
- task_ids = [task.task_id for task in dag.tasks]
-
# Only return DAGs that are not paused
if dag_id not in paused_dag_ids:
- simple_dags.append(SimpleDag(dag.dag_id,
- task_ids,
- dag.full_filepath,
- dag.concurrency,
- dag.is_paused,
- pickle_id))
+ simple_dags.append(SimpleDag(dag, pickle_id=pickle_id))
if len(self.dag_ids) > 0:
dags = [dag for dag in dagbag.dags.values()
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index e764d85..e3c52b5 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -65,6 +65,7 @@ from airflow.dag.base_dag import BaseDag, BaseDagBag
from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
+from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep
from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS
from airflow.utils.dates import cron_presets, date_range as utils_date_range
@@ -1835,6 +1836,15 @@ class TaskInstance(Base, LoggingMixin):
else:
return pull_fn(task_id=task_ids)
+ @provide_session
+ def get_num_running_task_instances(self, session):
+ TI = TaskInstance
+ return session.query(TI).filter(
+ TI.dag_id == self.dag_id,
+ TI.task_id == self.task_id,
+ TI.state == State.RUNNING
+ ).count()
+
class TaskFail(Base):
"""
@@ -2058,6 +2068,9 @@ class BaseOperator(LoggingMixin):
:type resources: dict
:param run_as_user: unix username to impersonate while running the task
:type run_as_user: str
+ :param task_concurrency: When set, a task will be able to limit the concurrent
+ runs across execution_dates
+ :type task_concurrency: int
"""
# For derived classes to define which fields will get jinjaified
@@ -2100,6 +2113,7 @@ class BaseOperator(LoggingMixin):
trigger_rule=TriggerRule.ALL_SUCCESS,
resources=None,
run_as_user=None,
+ task_concurrency=None,
*args,
**kwargs):
@@ -2165,6 +2179,7 @@ class BaseOperator(LoggingMixin):
self.priority_weight = priority_weight
self.resources = Resources(**(resources or {}))
self.run_as_user = run_as_user
+ self.task_concurrency = task_concurrency
# Private attributes
self._upstream_task_ids = []
@@ -4542,8 +4557,9 @@ class DagRun(Base, LoggingMixin):
session=session
)
none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks)
+ none_task_concurrency = all(t.task.task_concurrency is None for t in unfinished_tasks)
# small speed up
- if unfinished_tasks and none_depends_on_past:
+ if unfinished_tasks and none_depends_on_past and none_task_concurrency:
# todo: this can actually get pretty slow: one task costs between 0.01-015s
no_dependencies_met = True
for ut in unfinished_tasks:
@@ -4581,7 +4597,8 @@ class DagRun(Base, LoggingMixin):
self.state = State.SUCCESS
# if *all tasks* are deadlocked, the run failed
- elif unfinished_tasks and none_depends_on_past and no_dependencies_met:
+ elif (unfinished_tasks and none_depends_on_past and
+ none_task_concurrency and no_dependencies_met):
self.log.info('Deadlock; marking run %s failed', self)
self.state = State.FAILED
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/airflow/ti_deps/dep_context.py
----------------------------------------------------------------------
diff --git a/airflow/ti_deps/dep_context.py b/airflow/ti_deps/dep_context.py
index 01e01dd..f461a81 100644
--- a/airflow/ti_deps/dep_context.py
+++ b/airflow/ti_deps/dep_context.py
@@ -19,6 +19,7 @@ from airflow.ti_deps.deps.not_running_dep import NotRunningDep
from airflow.ti_deps.deps.not_skipped_dep import NotSkippedDep
from airflow.ti_deps.deps.runnable_exec_date_dep import RunnableExecDateDep
from airflow.ti_deps.deps.valid_state_dep import ValidStateDep
+from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep
from airflow.utils.state import State
@@ -97,7 +98,8 @@ QUEUE_DEPS = {
# Dependencies that need to be met for a given task instance to be able to get run by an
# executor. This class just extends QueueContext by adding dependencies for resources.
RUN_DEPS = QUEUE_DEPS | {
- DagTISlotsAvailableDep()
+ DagTISlotsAvailableDep(),
+ TaskConcurrencyDep(),
}
# TODO(aoen): SCHEDULER_DEPS is not coupled to actual execution in any way and
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/airflow/ti_deps/deps/task_concurrency_dep.py
----------------------------------------------------------------------
diff --git a/airflow/ti_deps/deps/task_concurrency_dep.py b/airflow/ti_deps/deps/task_concurrency_dep.py
new file mode 100644
index 0000000..99df5ac
--- /dev/null
+++ b/airflow/ti_deps/deps/task_concurrency_dep.py
@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
+from airflow.utils.db import provide_session
+
+
+class TaskConcurrencyDep(BaseTIDep):
+ """
+ This restricts the number of running task instances for a particular task.
+ """
+ NAME = "Task Concurrency"
+ IGNOREABLE = True
+ IS_TASK_DEP = True
+
+ @provide_session
+ def _get_dep_statuses(self, ti, session, dep_context):
+ if ti.task.task_concurrency is None:
+ yield self._passing_status(reason="Task concurrency is not set.")
+ return
+
+ if ti.get_num_running_task_instances(session) >= ti.task.task_concurrency:
+ yield self._failing_status(reason="The max task concurrency has been reached.")
+ return
+ else:
+ yield self._passing_status(reason="The max task concurrency has not been reached.")
+ return
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/airflow/utils/dag_processing.py
----------------------------------------------------------------------
diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py
index d8c13ea..b80f701 100644
--- a/airflow/utils/dag_processing.py
+++ b/airflow/utils/dag_processing.py
@@ -42,7 +42,8 @@ class SimpleDag(BaseDag):
full_filepath,
concurrency,
is_paused,
- pickle_id):
+ pickle_id,
+ task_special_args):
"""
:param dag_id: ID of the DAG
:type dag_id: unicode
@@ -66,6 +67,22 @@ class SimpleDag(BaseDag):
self._is_paused = is_paused
self._concurrency = concurrency
self._pickle_id = pickle_id
+ self._task_special_args = task_special_args
+
+ def __init__(self, dag, pickle_id=None):
+ self._dag_id = dag.dag_id
+ self._task_ids = [task.task_id for task in dag.tasks]
+ self._full_filepath = dag.full_filepath
+ self._is_paused = dag.is_paused
+ self._concurrency = dag.concurrency
+ self._pickle_id = pickle_id
+ self._task_special_args = {}
+ for task in dag.tasks:
+ special_args = {}
+ if task.task_concurrency is not None:
+ special_args['task_concurrency'] = task.task_concurrency
+ if len(special_args) > 0:
+ self._task_special_args[task.task_id] = special_args
@property
def dag_id(self):
@@ -115,6 +132,16 @@ class SimpleDag(BaseDag):
"""
return self._pickle_id
+ @property
+ def task_special_args(self):
+ return self._task_special_args
+
+ def get_task_special_arg(self, task_id, special_arg_name):
+ if task_id in self._task_special_args and special_arg_name in self._task_special_args[task_id]:
+ return self._task_special_args[task_id][special_arg_name]
+ else:
+ return None
+
class SimpleDagBag(BaseDagBag):
"""
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/tests/jobs.py
----------------------------------------------------------------------
diff --git a/tests/jobs.py b/tests/jobs.py
index 0a7f213..ba08fd6 100644
--- a/tests/jobs.py
+++ b/tests/jobs.py
@@ -41,7 +41,7 @@ from airflow.utils.dates import days_ago
from airflow.utils.db import provide_session
from airflow.utils.state import State
from airflow.utils.timeout import timeout
-from airflow.utils.dag_processing import SimpleDagBag, list_py_file_paths
+from airflow.utils.dag_processing import SimpleDag, SimpleDagBag, list_py_file_paths
from mock import Mock, patch
from sqlalchemy.orm.session import make_transient
@@ -932,13 +932,16 @@ class SchedulerJobTest(unittest.TestCase):
scheduler.heartrate = 0
scheduler.run()
+ def _make_simple_dag_bag(self, dags):
+ return SimpleDagBag([SimpleDag(dag) for dag in dags])
+
def test_execute_task_instances_is_paused_wont_execute(self):
dag_id = 'SchedulerJobTest.test_execute_task_instances_is_paused_wont_execute'
task_id_1 = 'dummy_task'
dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)
task1 = DummyOperator(dag=dag, task_id=task_id_1)
- dagbag = SimpleDagBag([dag])
+ dagbag = self._make_simple_dag_bag([dag])
scheduler = SchedulerJob(**self.default_scheduler_args)
session = settings.Session()
@@ -968,7 +971,7 @@ class SchedulerJobTest(unittest.TestCase):
dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)
task1 = DummyOperator(dag=dag, task_id=task_id_1)
- dagbag = SimpleDagBag([dag])
+ dagbag = self._make_simple_dag_bag([dag])
scheduler = SchedulerJob(**self.default_scheduler_args)
session = settings.Session()
@@ -993,7 +996,7 @@ class SchedulerJobTest(unittest.TestCase):
dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)
task1 = DummyOperator(dag=dag, task_id=task_id_1)
- dagbag = SimpleDagBag([dag])
+ dagbag = self._make_simple_dag_bag([dag])
scheduler = SchedulerJob(**self.default_scheduler_args)
session = settings.Session()
@@ -1018,7 +1021,7 @@ class SchedulerJobTest(unittest.TestCase):
task_id_1 = 'dummy'
dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
task1 = DummyOperator(dag=dag, task_id=task_id_1)
- dagbag = SimpleDagBag([dag])
+ dagbag = self._make_simple_dag_bag([dag])
scheduler = SchedulerJob(**self.default_scheduler_args)
session = settings.Session()
@@ -1058,7 +1061,7 @@ class SchedulerJobTest(unittest.TestCase):
dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
task1 = DummyOperator(dag=dag, task_id=task_id_1, pool='a')
task2 = DummyOperator(dag=dag, task_id=task_id_2, pool='b')
- dagbag = SimpleDagBag([dag])
+ dagbag = self._make_simple_dag_bag([dag])
scheduler = SchedulerJob(**self.default_scheduler_args)
session = settings.Session()
@@ -1099,7 +1102,7 @@ class SchedulerJobTest(unittest.TestCase):
task_id_1 = 'dummy'
dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
task1 = DummyOperator(dag=dag, task_id=task_id_1)
- dagbag = SimpleDagBag([dag])
+ dagbag = self._make_simple_dag_bag([dag])
scheduler = SchedulerJob(**self.default_scheduler_args)
session = settings.Session()
@@ -1117,7 +1120,7 @@ class SchedulerJobTest(unittest.TestCase):
task_id_1 = 'dummy'
dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2)
task1 = DummyOperator(dag=dag, task_id=task_id_1)
- dagbag = SimpleDagBag([dag])
+ dagbag = self._make_simple_dag_bag([dag])
scheduler = SchedulerJob(**self.default_scheduler_args)
session = settings.Session()
@@ -1158,6 +1161,98 @@ class SchedulerJobTest(unittest.TestCase):
self.assertEqual(0, len(res))
+ def test_find_executable_task_instances_task_concurrency(self):
+ dag_id = 'SchedulerJobTest.test_find_executable_task_instances_task_concurrency'
+ task_id_1 = 'dummy'
+ task_id_2 = 'dummy2'
+ dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
+ task1 = DummyOperator(dag=dag, task_id=task_id_1, task_concurrency=2)
+ task2 = DummyOperator(dag=dag, task_id=task_id_2)
+ dagbag = self._make_simple_dag_bag([dag])
+
+ scheduler = SchedulerJob(**self.default_scheduler_args)
+ session = settings.Session()
+
+ dr1 = scheduler.create_dag_run(dag)
+ dr2 = scheduler.create_dag_run(dag)
+ dr3 = scheduler.create_dag_run(dag)
+
+ ti1_1 = TI(task1, dr1.execution_date)
+ ti2 = TI(task2, dr1.execution_date)
+
+ ti1_1.state = State.SCHEDULED
+ ti2.state = State.SCHEDULED
+ session.merge(ti1_1)
+ session.merge(ti2)
+ session.commit()
+
+ res = scheduler._find_executable_task_instances(
+ dagbag,
+ states=[State.SCHEDULED],
+ session=session)
+
+ self.assertEqual(2, len(res))
+
+ ti1_1.state = State.RUNNING
+ ti2.state = State.RUNNING
+ ti1_2 = TI(task1, dr2.execution_date)
+ ti1_2.state = State.SCHEDULED
+ session.merge(ti1_1)
+ session.merge(ti2)
+ session.merge(ti1_2)
+ session.commit()
+
+ res = scheduler._find_executable_task_instances(
+ dagbag,
+ states=[State.SCHEDULED],
+ session=session)
+
+ self.assertEqual(1, len(res))
+
+ ti1_2.state = State.RUNNING
+ ti1_3 = TI(task1, dr3.execution_date)
+ ti1_3.state = State.SCHEDULED
+ session.merge(ti1_2)
+ session.merge(ti1_3)
+ session.commit()
+
+ res = scheduler._find_executable_task_instances(
+ dagbag,
+ states=[State.SCHEDULED],
+ session=session)
+
+ self.assertEqual(0, len(res))
+
+ ti1_1.state = State.SCHEDULED
+ ti1_2.state = State.SCHEDULED
+ ti1_3.state = State.SCHEDULED
+ session.merge(ti1_1)
+ session.merge(ti1_2)
+ session.merge(ti1_3)
+ session.commit()
+
+ res = scheduler._find_executable_task_instances(
+ dagbag,
+ states=[State.SCHEDULED],
+ session=session)
+
+ self.assertEqual(2, len(res))
+
+ ti1_1.state = State.RUNNING
+ ti1_2.state = State.SCHEDULED
+ ti1_3.state = State.SCHEDULED
+ session.merge(ti1_1)
+ session.merge(ti1_2)
+ session.merge(ti1_3)
+ session.commit()
+
+ res = scheduler._find_executable_task_instances(
+ dagbag,
+ states=[State.SCHEDULED],
+ session=session)
+
+ self.assertEqual(1, len(res))
+
def test_change_state_for_executable_task_instances_no_tis(self):
scheduler = SchedulerJob(**self.default_scheduler_args)
session = settings.Session()
@@ -1169,7 +1264,7 @@ class SchedulerJobTest(unittest.TestCase):
task_id_1 = 'dummy'
dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2)
task1 = DummyOperator(dag=dag, task_id=task_id_1)
- dagbag = SimpleDagBag([dag])
+ dagbag = self._make_simple_dag_bag([dag])
scheduler = SchedulerJob(**self.default_scheduler_args)
session = settings.Session()
@@ -1201,7 +1296,7 @@ class SchedulerJobTest(unittest.TestCase):
task_id_1 = 'dummy'
dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2)
task1 = DummyOperator(dag=dag, task_id=task_id_1)
- dagbag = SimpleDagBag([dag])
+ dagbag = self._make_simple_dag_bag([dag])
scheduler = SchedulerJob(**self.default_scheduler_args)
session = settings.Session()
@@ -1237,7 +1332,7 @@ class SchedulerJobTest(unittest.TestCase):
task_id_1 = 'dummy'
dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)
task1 = DummyOperator(dag=dag, task_id=task_id_1)
- dagbag = SimpleDagBag([dag])
+ dagbag = self._make_simple_dag_bag([dag])
scheduler = SchedulerJob(**self.default_scheduler_args)
session = settings.Session()
@@ -1282,7 +1377,7 @@ class SchedulerJobTest(unittest.TestCase):
dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=3)
task1 = DummyOperator(dag=dag, task_id=task_id_1)
task2 = DummyOperator(dag=dag, task_id=task_id_2)
- dagbag = SimpleDagBag([dag])
+ dagbag = self._make_simple_dag_bag([dag])
scheduler = SchedulerJob(**self.default_scheduler_args)
session = settings.Session()
@@ -1343,7 +1438,7 @@ class SchedulerJobTest(unittest.TestCase):
dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
task1 = DummyOperator(dag=dag, task_id=task_id_1)
task2 = DummyOperator(dag=dag, task_id=task_id_2)
- dagbag = SimpleDagBag([dag])
+ dagbag = self._make_simple_dag_bag([dag])
scheduler = SchedulerJob(**self.default_scheduler_args)
scheduler.max_tis_per_query = 3
@@ -1410,16 +1505,18 @@ class SchedulerJobTest(unittest.TestCase):
ti2.state = State.SCHEDULED
session.commit()
- dagbag = SimpleDagBag([dag])
+ dagbag = self._make_simple_dag_bag([dag])
scheduler = SchedulerJob(num_runs=0, run_duration=0)
scheduler._change_state_for_tis_without_dagrun(simple_dag_bag=dagbag,
old_states=[State.SCHEDULED, State.QUEUED],
new_state=State.NONE,
session=session)
+ ti = dr.get_task_instance(task_id='dummy', session=session)
ti.refresh_from_db(session=session)
self.assertEqual(ti.state, State.SCHEDULED)
+ ti2 = dr2.get_task_instance(task_id='dummy', session=session)
ti2.refresh_from_db(session=session)
self.assertEqual(ti2.state, State.SCHEDULED)
@@ -2042,7 +2139,7 @@ class SchedulerJobTest(unittest.TestCase):
queue = []
scheduler._process_task_instances(dag, queue=queue)
self.assertEquals(len(queue), 2)
- dagbag = SimpleDagBag([dag])
+ dagbag = self._make_simple_dag_bag([dag])
# Recreated part of the scheduler here, to kick off tasks -> executor
for ti_key in queue:
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/tests/models.py
----------------------------------------------------------------------
diff --git a/tests/models.py b/tests/models.py
index db5beca..a1de17d 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -517,6 +517,39 @@ class DagRunTest(unittest.TestCase):
dr.update_state()
self.assertEqual(dr.state, State.FAILED)
+ def test_dagrun_no_deadlock(self):
+ session = settings.Session()
+ dag = DAG('test_dagrun_no_deadlock',
+ start_date=DEFAULT_DATE)
+ with dag:
+ op1 = DummyOperator(task_id='dop', depends_on_past=True)
+ op2 = DummyOperator(task_id='tc', task_concurrency=1)
+
+ dag.clear()
+ dr = dag.create_dagrun(run_id='test_dagrun_no_deadlock_1',
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE,
+ start_date=DEFAULT_DATE)
+ dr2 = dag.create_dagrun(run_id='test_dagrun_no_deadlock_2',
+ state=State.RUNNING,
+ execution_date=DEFAULT_DATE + datetime.timedelta(days=1),
+ start_date=DEFAULT_DATE + datetime.timedelta(days=1))
+ ti1_op1 = dr.get_task_instance(task_id='dop')
+ ti2_op1 = dr2.get_task_instance(task_id='dop')
+ ti2_op1 = dr.get_task_instance(task_id='tc')
+ ti2_op2 = dr.get_task_instance(task_id='tc')
+ ti1_op1.set_state(state=State.RUNNING, session=session)
+ dr.update_state()
+ dr2.update_state()
+ self.assertEqual(dr.state, State.RUNNING)
+ self.assertEqual(dr2.state, State.RUNNING)
+
+ ti2_op1.set_state(state=State.RUNNING, session=session)
+ dr.update_state()
+ dr2.update_state()
+ self.assertEqual(dr.state, State.RUNNING)
+ self.assertEqual(dr2.state, State.RUNNING)
+
def test_get_task_instance_on_empty_dagrun(self):
"""
Make sure that a proper value is returned when a dagrun has no task instances
@@ -1201,6 +1234,29 @@ class TaskInstanceTest(unittest.TestCase):
ti = TI(
task=task2, execution_date=datetime.datetime.now())
self.assertFalse(ti._check_and_change_state_before_execution())
+
+ def test_get_num_running_task_instances(self):
+ session = settings.Session()
+
+ dag = models.DAG(dag_id='test_get_num_running_task_instances')
+ dag2 = models.DAG(dag_id='test_get_num_running_task_instances_dummy')
+ task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
+ task2 = DummyOperator(task_id='task', dag=dag2, start_date=DEFAULT_DATE)
+
+ ti1 = TI(task=task, execution_date=DEFAULT_DATE)
+ ti2 = TI(task=task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1))
+ ti3 = TI(task=task2, execution_date=DEFAULT_DATE)
+ ti1.state = State.RUNNING
+ ti2.state = State.QUEUED
+ ti3.state = State.RUNNING
+ session.add(ti1)
+ session.add(ti2)
+ session.add(ti3)
+ session.commit()
+
+ self.assertEquals(1, ti1.get_num_running_task_instances(session=session))
+ self.assertEquals(1, ti2.get_num_running_task_instances(session=session))
+ self.assertEquals(1, ti3.get_num_running_task_instances(session=session))
class ClearTasksTest(unittest.TestCase):
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/tests/ti_deps/deps/test_task_concurrency.py
----------------------------------------------------------------------
diff --git a/tests/ti_deps/deps/test_task_concurrency.py b/tests/ti_deps/deps/test_task_concurrency.py
new file mode 100644
index 0000000..77a5990
--- /dev/null
+++ b/tests/ti_deps/deps/test_task_concurrency.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from datetime import datetime
+from mock import Mock
+
+from airflow.models import DAG, BaseOperator
+from airflow.ti_deps.dep_context import DepContext
+from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep
+from airflow.utils.state import State
+
+
+class TaskConcurrencyDepTest(unittest.TestCase):
+
+ def _get_task(self, **kwargs):
+ return BaseOperator(task_id='test_task', dag=DAG('test_dag'), **kwargs)
+
+ def test_not_task_concurrency(self):
+ task = self._get_task(start_date=datetime(2016, 1, 1))
+ dep_context = DepContext()
+ ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
+ self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context))
+
+ def test_not_reached_concurrency(self):
+ task = self._get_task(start_date=datetime(2016, 1, 1), task_concurrency=1)
+ dep_context = DepContext()
+ ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
+ ti.get_num_running_task_instances = lambda x: 0
+ self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context))
+
+ def test_reached_concurrency(self):
+ task = self._get_task(start_date=datetime(2016, 1, 1), task_concurrency=2)
+ dep_context = DepContext()
+ ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
+ ti.get_num_running_task_instances = lambda x: 1
+ self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context))
+ ti.get_num_running_task_instances = lambda x: 2
+ self.assertFalse(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context))
+