Posted to commits@airflow.apache.org by sa...@apache.org on 2017/10/05 21:37:31 UTC

incubator-airflow git commit: [AIRFLOW-1634] Adds task_concurrency feature

Repository: incubator-airflow
Updated Branches:
  refs/heads/master 96206b0e5 -> cfc2f73c4


[AIRFLOW-1634] Adds task_concurrency feature

This adds a feature to limit the concurrency of individual tasks. By
default, existing behavior is unchanged.

Closes #2624 from saguziel/aguziel-task-concurrency
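
For context, a minimal usage sketch (the DAG and task names here are
illustrative, not taken from this commit):

    from datetime import datetime

    from airflow import DAG
    from airflow.operators.dummy_operator import DummyOperator

    # With task_concurrency=2, at most two instances of `capped_task`
    # may run at once, summed across all execution_dates of this DAG.
    dag = DAG('task_concurrency_example', start_date=datetime(2017, 1, 1))
    capped_task = DummyOperator(task_id='capped_task',
                                task_concurrency=2,
                                dag=dag)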


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/cfc2f73c
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/cfc2f73c
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/cfc2f73c

Branch: refs/heads/master
Commit: cfc2f73c445074e1e09d6ef6a056cd2b33a945da
Parents: 96206b0
Author: Alex Guziel <al...@airbnb.com>
Authored: Thu Oct 5 14:37:26 2017 -0700
Committer: Alex Guziel <al...@airbnb.com>
Committed: Thu Oct 5 14:37:26 2017 -0700

----------------------------------------------------------------------
 airflow/jobs.py                              |  51 +++++++--
 airflow/models.py                            |  21 +++-
 airflow/ti_deps/dep_context.py               |   4 +-
 airflow/ti_deps/deps/task_concurrency_dep.py |  37 +++++++
 airflow/utils/dag_processing.py              |  29 ++++-
 tests/jobs.py                                | 127 +++++++++++++++++++---
 tests/models.py                              |  56 ++++++++++
 tests/ti_deps/deps/test_task_concurrency.py  |  51 +++++++++
 8 files changed, 348 insertions(+), 28 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/airflow/jobs.py
----------------------------------------------------------------------
diff --git a/airflow/jobs.py b/airflow/jobs.py
index 8ca81dc..2675bd3 100644
--- a/airflow/jobs.py
+++ b/airflow/jobs.py
@@ -1024,6 +1024,30 @@ class SchedulerJob(BaseJob):
             )
 
     @provide_session
+    def __get_task_concurrency_map(self, states, session=None):
+        """
+        Returns a map from (dag_id, task_id) to the number of task
+        instances in the given states.
+
+        :param states: List of states to query for
+        :type states: List[State]
+        :return: A map from (dag_id, task_id) to count of task instances
+            in the given states
+        :rtype: Dict[Tuple[String, String], Int]
+        """
+        TI = models.TaskInstance
+        ti_concurrency_query = (
+            session
+            .query(TI.task_id, TI.dag_id, func.count('*'))
+            .filter(TI.state.in_(states))
+            .group_by(TI.task_id, TI.dag_id)
+        ).all()
+        task_map = defaultdict(int)
+        for result in ti_concurrency_query:
+            task_id, dag_id, count = result
+            task_map[(dag_id, task_id)] = count
+        return task_map
+
+    @provide_session
     def _find_executable_task_instances(self, simple_dag_bag, states, session=None):
         """
         Finds TIs that are ready for execution with respect to pool limits,
@@ -1038,6 +1062,9 @@ class SchedulerJob(BaseJob):
         :type states: Tuple[State]
         :return: List[TaskInstance]
         """
+        # TODO(saguziel): Change this to include QUEUED; for concurrency
+        # purposes we may want to count queued tasks
+        states_to_count_as_running = [State.RUNNING]
         executable_tis = []
 
         # Get all the queued task instances from associated with scheduled
@@ -1082,6 +1109,8 @@ class SchedulerJob(BaseJob):
         for task_instance in task_instances_to_examine:
             pool_to_task_instances[task_instance.pool].append(task_instance)
 
+        task_concurrency_map = self.__get_task_concurrency_map(
+            states=states_to_count_as_running, session=session)
+
         # Go through each pool, and queue up a task for execution if there are
         # any open slots in the pool.
         for pool, task_instances in pool_to_task_instances.items():
@@ -1119,6 +1148,7 @@ class SchedulerJob(BaseJob):
                 # Check to make sure that the task concurrency of the DAG hasn't been
                 # reached.
                 dag_id = task_instance.dag_id
+                simple_dag = simple_dag_bag.get_dag(dag_id)
 
                 if dag_id not in dag_id_to_possibly_running_task_count:
                     # TODO(saguziel): also check against QUEUED state, see AIRFLOW-1104
@@ -1126,7 +1156,7 @@ class SchedulerJob(BaseJob):
                         DAG.get_num_task_instances(
                             dag_id,
                             simple_dag_bag.get_dag(dag_id).task_ids,
-                            states=[State.RUNNING],
+                            states=states_to_count_as_running,
                             session=session)
 
                 current_task_concurrency = dag_id_to_possibly_running_task_count[dag_id]
@@ -1143,6 +1173,16 @@ class SchedulerJob(BaseJob):
                     )
                     continue
 
+                task_concurrency = simple_dag.get_task_special_arg(
+                    task_instance.task_id, 'task_concurrency')
+                if task_concurrency is not None:
+                    num_running = task_concurrency_map[
+                        (task_instance.dag_id, task_instance.task_id)]
+                    if num_running >= task_concurrency:
+                        self.log.info("Not executing %s since the task concurrency for this task"
+                                      " has been reached.", task_instance)
+                        continue
+                    else:
+                        task_concurrency_map[(task_instance.dag_id, task_instance.task_id)] += 1
+
                 if self.executor.has_task(task_instance):
                     self.log.debug(
                         "Not handling task %s as the executor reports it is running",
@@ -1723,16 +1763,9 @@ class SchedulerJob(BaseJob):
             if pickle_dags:
                 pickle_id = dag.pickle(session).id
 
-            task_ids = [task.task_id for task in dag.tasks]
-
             # Only return DAGs that are not paused
             if dag_id not in paused_dag_ids:
-                simple_dags.append(SimpleDag(dag.dag_id,
-                                             task_ids,
-                                             dag.full_filepath,
-                                             dag.concurrency,
-                                             dag.is_paused,
-                                             pickle_id))
+                simple_dags.append(SimpleDag(dag, pickle_id=pickle_id))
 
         if len(self.dag_ids) > 0:
             dags = [dag for dag in dagbag.dags.values()
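
A hedged distillation of the new per-task gate in
_find_executable_task_instances (the helper below is editorial
shorthand, not code from the commit):

    def passes_task_concurrency_gate(task_instance, simple_dag,
                                     task_concurrency_map):
        # task_concurrency_map is the defaultdict(int) built by
        # __get_task_concurrency_map, keyed by (dag_id, task_id).
        cap = simple_dag.get_task_special_arg(task_instance.task_id,
                                              'task_concurrency')
        if cap is None:
            return True  # no per-task limit; existing behavior applies
        key = (task_instance.dag_id, task_instance.task_id)
        if task_concurrency_map[key] >= cap:
            return False  # cap reached; skip this TI in this pass
        task_concurrency_map[key] += 1  # reserve a slot for this pass
        return True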

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index e764d85..e3c52b5 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -65,6 +65,7 @@ from airflow.dag.base_dag import BaseDag, BaseDagBag
 from airflow.ti_deps.deps.not_in_retry_period_dep import NotInRetryPeriodDep
 from airflow.ti_deps.deps.prev_dagrun_dep import PrevDagrunDep
 from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
+from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep
 
 from airflow.ti_deps.dep_context import DepContext, QUEUE_DEPS, RUN_DEPS
 from airflow.utils.dates import cron_presets, date_range as utils_date_range
@@ -1835,6 +1836,15 @@ class TaskInstance(Base, LoggingMixin):
         else:
             return pull_fn(task_id=task_ids)
 
+    @provide_session
+    def get_num_running_task_instances(self, session):
+        TI = TaskInstance
+        return session.query(TI).filter(
+            TI.dag_id == self.dag_id,
+            TI.task_id == self.task_id,
+            TI.state == State.RUNNING
+        ).count()
+
 
 class TaskFail(Base):
     """
@@ -2058,6 +2068,9 @@ class BaseOperator(LoggingMixin):
     :type resources: dict
     :param run_as_user: unix username to impersonate while running the task
     :type run_as_user: str
+    :param task_concurrency: When set, limits the number of concurrently running
+        task instances for this task across all execution_dates
+    :type task_concurrency: int
     """
 
     # For derived classes to define which fields will get jinjaified
@@ -2100,6 +2113,7 @@ class BaseOperator(LoggingMixin):
             trigger_rule=TriggerRule.ALL_SUCCESS,
             resources=None,
             run_as_user=None,
+            task_concurrency=None,
             *args,
             **kwargs):
 
@@ -2165,6 +2179,7 @@ class BaseOperator(LoggingMixin):
         self.priority_weight = priority_weight
         self.resources = Resources(**(resources or {}))
         self.run_as_user = run_as_user
+        self.task_concurrency = task_concurrency
 
         # Private attributes
         self._upstream_task_ids = []
@@ -4542,8 +4557,9 @@ class DagRun(Base, LoggingMixin):
             session=session
         )
         none_depends_on_past = all(not t.task.depends_on_past for t in unfinished_tasks)
+        none_task_concurrency = all(t.task.task_concurrency is None for t in unfinished_tasks)
         # small speed up
-        if unfinished_tasks and none_depends_on_past:
+        if unfinished_tasks and none_depends_on_past and none_task_concurrency:
             # todo: this can actually get pretty slow: one task costs between 0.01-0.15s
             no_dependencies_met = True
             for ut in unfinished_tasks:
@@ -4581,7 +4597,8 @@ class DagRun(Base, LoggingMixin):
                 self.state = State.SUCCESS
 
             # if *all tasks* are deadlocked, the run failed
-            elif unfinished_tasks and none_depends_on_past and no_dependencies_met:
+            elif (unfinished_tasks and none_depends_on_past and
+                  none_task_concurrency and no_dependencies_met):
                 self.log.info('Deadlock; marking run %s failed', self)
                 self.state = State.FAILED
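
Two notes on the models.py changes. First, the deadlock shortcut in
DagRun.update_state now also requires none_task_concurrency: a task
instance blocked only by its per-task cap reports no dependencies met in
the current pass, yet becomes runnable once a sibling run's instance
finishes, so treating that as a deadlock would wrongly fail the run.
Second, the new TaskInstance helper can be called directly; a hedged
sketch (the wrapper function is editorial):

    from airflow import settings

    def print_running_count(ti):
        # `ti` is a TaskInstance; the helper counts RUNNING instances of
        # the same (dag_id, task_id) across execution_dates.
        session = settings.Session()
        print("RUNNING instances of %s.%s: %d" % (
            ti.dag_id, ti.task_id,
            ti.get_num_running_task_instances(session=session)))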
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/airflow/ti_deps/dep_context.py
----------------------------------------------------------------------
diff --git a/airflow/ti_deps/dep_context.py b/airflow/ti_deps/dep_context.py
index 01e01dd..f461a81 100644
--- a/airflow/ti_deps/dep_context.py
+++ b/airflow/ti_deps/dep_context.py
@@ -19,6 +19,7 @@ from airflow.ti_deps.deps.not_running_dep import NotRunningDep
 from airflow.ti_deps.deps.not_skipped_dep import NotSkippedDep
 from airflow.ti_deps.deps.runnable_exec_date_dep import RunnableExecDateDep
 from airflow.ti_deps.deps.valid_state_dep import ValidStateDep
+from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep
 from airflow.utils.state import State
 
 
@@ -97,7 +98,8 @@ QUEUE_DEPS = {
 # Dependencies that need to be met for a given task instance to be able to get run by an
 # executor. This class just extends QueueContext by adding dependencies for resources.
 RUN_DEPS = QUEUE_DEPS | {
-    DagTISlotsAvailableDep()
+    DagTISlotsAvailableDep(),
+    TaskConcurrencyDep(),
 }
 
 # TODO(aoen): SCHEDULER_DEPS is not coupled to actual execution in any way and

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/airflow/ti_deps/deps/task_concurrency_dep.py
----------------------------------------------------------------------
diff --git a/airflow/ti_deps/deps/task_concurrency_dep.py b/airflow/ti_deps/deps/task_concurrency_dep.py
new file mode 100644
index 0000000..99df5ac
--- /dev/null
+++ b/airflow/ti_deps/deps/task_concurrency_dep.py
@@ -0,0 +1,37 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
+from airflow.utils.db import provide_session
+
+
+class TaskConcurrencyDep(BaseTIDep):
+    """
+    This restricts the number of running task instances for a particular task.
+    """
+    NAME = "Task Concurrency"
+    IGNOREABLE = True
+    IS_TASK_DEP = True
+
+    @provide_session
+    def _get_dep_statuses(self, ti, session, dep_context):
+        if ti.task.task_concurrency is None:
+            yield self._passing_status(reason="Task concurrency is not set.")
+            return
+
+        if ti.get_num_running_task_instances(session) >= ti.task.task_concurrency:
+            yield self._failing_status(reason="The max task concurrency has been reached.")
+        else:
+            yield self._passing_status(reason="The max task concurrency has not been reached.")
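
A hedged sketch of exercising the dep outside the scheduler (`ti` is any
TaskInstance; the wrapper function is editorial, and the is_met call
mirrors the one in the new tests below):

    from airflow.ti_deps.dep_context import DepContext
    from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep

    def blocked_by_task_concurrency(ti):
        # True when this dep fails for `ti`, i.e. the per-task cap is
        # currently reached.
        return not TaskConcurrencyDep().is_met(ti=ti,
                                               dep_context=DepContext())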

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/airflow/utils/dag_processing.py
----------------------------------------------------------------------
diff --git a/airflow/utils/dag_processing.py b/airflow/utils/dag_processing.py
index d8c13ea..b80f701 100644
--- a/airflow/utils/dag_processing.py
+++ b/airflow/utils/dag_processing.py
@@ -42,7 +42,8 @@ class SimpleDag(BaseDag):
                  full_filepath,
                  concurrency,
                  is_paused,
-                 pickle_id):
+                 pickle_id,
+                 task_special_args):
         """
         :param dag_id: ID of the DAG
         :type dag_id: unicode
@@ -66,6 +67,22 @@ class SimpleDag(BaseDag):
         self._is_paused = is_paused
         self._concurrency = concurrency
         self._pickle_id = pickle_id
+        self._task_special_args = task_special_args
+
+    def __init__(self, dag, pickle_id=None):
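+        """
+        :param dag: the DAG this SimpleDag summarizes
+        :type dag: DAG
+        :param pickle_id: ID associated with the pickled version of this DAG
+        :type pickle_id: unicode
+        """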
+        self._dag_id = dag.dag_id
+        self._task_ids = [task.task_id for task in dag.tasks]
+        self._full_filepath = dag.full_filepath
+        self._is_paused = dag.is_paused
+        self._concurrency = dag.concurrency
+        self._pickle_id = pickle_id
+        self._task_special_args = {}
+        for task in dag.tasks:
+            special_args = {}
+            if task.task_concurrency is not None:
+                special_args['task_concurrency'] = task.task_concurrency
+            if len(special_args) > 0:
+                self._task_special_args[task.task_id] = special_args
 
     @property
     def dag_id(self):
@@ -115,6 +132,16 @@ class SimpleDag(BaseDag):
         """
         return self._pickle_id
 
+    @property
+    def task_special_args(self):
+        return self._task_special_args
+
+    def get_task_special_arg(self, task_id, special_arg_name):
+        if (task_id in self._task_special_args and
+                special_arg_name in self._task_special_args[task_id]):
+            return self._task_special_args[task_id][special_arg_name]
+        else:
+            return None
+
 
 class SimpleDagBag(BaseDagBag):
     """

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/tests/jobs.py
----------------------------------------------------------------------
diff --git a/tests/jobs.py b/tests/jobs.py
index 0a7f213..ba08fd6 100644
--- a/tests/jobs.py
+++ b/tests/jobs.py
@@ -41,7 +41,7 @@ from airflow.utils.dates import days_ago
 from airflow.utils.db import provide_session
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
-from airflow.utils.dag_processing import SimpleDagBag, list_py_file_paths
+from airflow.utils.dag_processing import SimpleDag, SimpleDagBag, list_py_file_paths
 
 from mock import Mock, patch
 from sqlalchemy.orm.session import make_transient
@@ -932,13 +932,16 @@ class SchedulerJobTest(unittest.TestCase):
         scheduler.heartrate = 0
         scheduler.run()
 
+    def _make_simple_dag_bag(self, dags):
+        return SimpleDagBag([SimpleDag(dag) for dag in dags])
+
     def test_execute_task_instances_is_paused_wont_execute(self):
         dag_id = 'SchedulerJobTest.test_execute_task_instances_is_paused_wont_execute'
         task_id_1 = 'dummy_task'
 
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -968,7 +971,7 @@ class SchedulerJobTest(unittest.TestCase):
 
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -993,7 +996,7 @@ class SchedulerJobTest(unittest.TestCase):
 
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1018,7 +1021,7 @@ class SchedulerJobTest(unittest.TestCase):
         task_id_1 = 'dummy'
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1058,7 +1061,7 @@ class SchedulerJobTest(unittest.TestCase):
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
         task1 = DummyOperator(dag=dag, task_id=task_id_1, pool='a')
         task2 = DummyOperator(dag=dag, task_id=task_id_2, pool='b')
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1099,7 +1102,7 @@ class SchedulerJobTest(unittest.TestCase):
         task_id_1 = 'dummy'
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1117,7 +1120,7 @@ class SchedulerJobTest(unittest.TestCase):
         task_id_1 = 'dummy'
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1158,6 +1161,98 @@ class SchedulerJobTest(unittest.TestCase):
 
         self.assertEqual(0, len(res))
 
+    def test_find_executable_task_instances_task_concurrency(self):
+        dag_id = 'SchedulerJobTest.test_find_executable_task_instances_task_concurrency'
+        task_id_1 = 'dummy'
+        task_id_2 = 'dummy2'
+        dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
+        task1 = DummyOperator(dag=dag, task_id=task_id_1, task_concurrency=2)
+        task2 = DummyOperator(dag=dag, task_id=task_id_2)
+        dagbag = self._make_simple_dag_bag([dag])
+
+        scheduler = SchedulerJob(**self.default_scheduler_args)
+        session = settings.Session()
+
+        dr1 = scheduler.create_dag_run(dag)
+        dr2 = scheduler.create_dag_run(dag)
+        dr3 = scheduler.create_dag_run(dag)
+
+        ti1_1 = TI(task1, dr1.execution_date)
+        ti2 = TI(task2, dr1.execution_date)
+
+        ti1_1.state = State.SCHEDULED
+        ti2.state = State.SCHEDULED
+        session.merge(ti1_1)
+        session.merge(ti2)
+        session.commit()
+
+        res = scheduler._find_executable_task_instances(
+            dagbag,
+            states=[State.SCHEDULED],
+            session=session)
+
+        self.assertEqual(2, len(res))
+
+        ti1_1.state = State.RUNNING
+        ti2.state = State.RUNNING
+        ti1_2 = TI(task1, dr2.execution_date)
+        ti1_2.state = State.SCHEDULED
+        session.merge(ti1_1)
+        session.merge(ti2)
+        session.merge(ti1_2)
+        session.commit()
+
+        res = scheduler._find_executable_task_instances(
+            dagbag,
+            states=[State.SCHEDULED],
+            session=session)
+
+        self.assertEqual(1, len(res))
+
+        ti1_2.state = State.RUNNING
+        ti1_3 = TI(task1, dr3.execution_date)
+        ti1_3.state = State.SCHEDULED
+        session.merge(ti1_2)
+        session.merge(ti1_3)
+        session.commit()
+
+        res = scheduler._find_executable_task_instances(
+            dagbag,
+            states=[State.SCHEDULED],
+            session=session)
+
+        self.assertEqual(0, len(res))
+
+        ti1_1.state = State.SCHEDULED
+        ti1_2.state = State.SCHEDULED
+        ti1_3.state = State.SCHEDULED
+        session.merge(ti1_1)
+        session.merge(ti1_2)
+        session.merge(ti1_3)
+        session.commit()
+
+        res = scheduler._find_executable_task_instances(
+            dagbag,
+            states=[State.SCHEDULED],
+            session=session)
+
+        self.assertEqual(2, len(res))
+
+        ti1_1.state = State.RUNNING
+        ti1_2.state = State.SCHEDULED
+        ti1_3.state = State.SCHEDULED
+        session.merge(ti1_1)
+        session.merge(ti1_2)
+        session.merge(ti1_3)
+        session.commit()
+
+        res = scheduler._find_executable_task_instances(
+            dagbag,
+            states=[State.SCHEDULED],
+            session=session)
+
+        self.assertEqual(1, len(res))
+
     def test_change_state_for_executable_task_instances_no_tis(self):
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1169,7 +1264,7 @@ class SchedulerJobTest(unittest.TestCase):
         task_id_1 = 'dummy'
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1201,7 +1296,7 @@ class SchedulerJobTest(unittest.TestCase):
         task_id_1 = 'dummy'
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=2)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1237,7 +1332,7 @@ class SchedulerJobTest(unittest.TestCase):
         task_id_1 = 'dummy'
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1282,7 +1377,7 @@ class SchedulerJobTest(unittest.TestCase):
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=3)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
         task2 = DummyOperator(dag=dag, task_id=task_id_2)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         session = settings.Session()
@@ -1343,7 +1438,7 @@ class SchedulerJobTest(unittest.TestCase):
         dag = DAG(dag_id=dag_id, start_date=DEFAULT_DATE, concurrency=16)
         task1 = DummyOperator(dag=dag, task_id=task_id_1)
         task2 = DummyOperator(dag=dag, task_id=task_id_2)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         scheduler = SchedulerJob(**self.default_scheduler_args)
         scheduler.max_tis_per_query = 3
@@ -1410,16 +1505,18 @@ class SchedulerJobTest(unittest.TestCase):
         ti2.state = State.SCHEDULED
         session.commit()
 
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
         scheduler = SchedulerJob(num_runs=0, run_duration=0)
         scheduler._change_state_for_tis_without_dagrun(simple_dag_bag=dagbag,
                                                        old_states=[State.SCHEDULED, State.QUEUED],
                                                        new_state=State.NONE,
                                                        session=session)
 
+        ti = dr.get_task_instance(task_id='dummy', session=session)
         ti.refresh_from_db(session=session)
         self.assertEqual(ti.state, State.SCHEDULED)
 
+        ti2 = dr2.get_task_instance(task_id='dummy', session=session)
         ti2.refresh_from_db(session=session)
         self.assertEqual(ti2.state, State.SCHEDULED)
 
@@ -2042,7 +2139,7 @@ class SchedulerJobTest(unittest.TestCase):
         queue = []
         scheduler._process_task_instances(dag, queue=queue)
         self.assertEquals(len(queue), 2)
-        dagbag = SimpleDagBag([dag])
+        dagbag = self._make_simple_dag_bag([dag])
 
         # Recreated part of the scheduler here, to kick off tasks -> executor
         for ti_key in queue:

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/tests/models.py
----------------------------------------------------------------------
diff --git a/tests/models.py b/tests/models.py
index db5beca..a1de17d 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -517,6 +517,39 @@ class DagRunTest(unittest.TestCase):
         dr.update_state()
         self.assertEqual(dr.state, State.FAILED)
 
+    def test_dagrun_no_deadlock(self):
+        session = settings.Session()
+        dag = DAG('test_dagrun_no_deadlock',
+                  start_date=DEFAULT_DATE)
+        with dag:
+            op1 = DummyOperator(task_id='dop', depends_on_past=True)
+            op2 = DummyOperator(task_id='tc', task_concurrency=1)
+
+        dag.clear()
+        dr = dag.create_dagrun(run_id='test_dagrun_no_deadlock_1',
+                               state=State.RUNNING,
+                               execution_date=DEFAULT_DATE,
+                               start_date=DEFAULT_DATE)
+        dr2 = dag.create_dagrun(run_id='test_dagrun_no_deadlock_2',
+                                state=State.RUNNING,
+                                execution_date=DEFAULT_DATE + datetime.timedelta(days=1),
+                                start_date=DEFAULT_DATE + datetime.timedelta(days=1))
+        ti1_op1 = dr.get_task_instance(task_id='dop')
+        ti2_op1 = dr2.get_task_instance(task_id='dop')
+        ti1_op2 = dr.get_task_instance(task_id='tc')
+        ti2_op2 = dr2.get_task_instance(task_id='tc')
+        ti1_op1.set_state(state=State.RUNNING, session=session)
+        dr.update_state()
+        dr2.update_state()
+        self.assertEqual(dr.state, State.RUNNING)
+        self.assertEqual(dr2.state, State.RUNNING)
+
+        ti1_op2.set_state(state=State.RUNNING, session=session)
+        dr.update_state()
+        dr2.update_state()
+        self.assertEqual(dr.state, State.RUNNING)
+        self.assertEqual(dr2.state, State.RUNNING)
+
     def test_get_task_instance_on_empty_dagrun(self):
         """
         Make sure that a proper value is returned when a dagrun has no task instances
@@ -1201,6 +1234,29 @@ class TaskInstanceTest(unittest.TestCase):
         ti = TI(
             task=task2, execution_date=datetime.datetime.now())
         self.assertFalse(ti._check_and_change_state_before_execution())
+
+    def test_get_num_running_task_instances(self):
+        session = settings.Session()
+
+        dag = models.DAG(dag_id='test_get_num_running_task_instances')
+        dag2 = models.DAG(dag_id='test_get_num_running_task_instances_dummy')
+        task = DummyOperator(task_id='task', dag=dag, start_date=DEFAULT_DATE)
+        task2 = DummyOperator(task_id='task', dag=dag2, start_date=DEFAULT_DATE)
+
+        ti1 = TI(task=task, execution_date=DEFAULT_DATE)
+        ti2 = TI(task=task, execution_date=DEFAULT_DATE + datetime.timedelta(days=1))
+        ti3 = TI(task=task2, execution_date=DEFAULT_DATE)
+        ti1.state = State.RUNNING
+        ti2.state = State.QUEUED
+        ti3.state = State.RUNNING
+        session.add(ti1)
+        session.add(ti2)
+        session.add(ti3)
+        session.commit()
+
+        self.assertEquals(1, ti1.get_num_running_task_instances(session=session))
+        self.assertEquals(1, ti2.get_num_running_task_instances(session=session))
+        self.assertEquals(1, ti3.get_num_running_task_instances(session=session))
         
 
 class ClearTasksTest(unittest.TestCase):

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/cfc2f73c/tests/ti_deps/deps/test_task_concurrency.py
----------------------------------------------------------------------
diff --git a/tests/ti_deps/deps/test_task_concurrency.py b/tests/ti_deps/deps/test_task_concurrency.py
new file mode 100644
index 0000000..77a5990
--- /dev/null
+++ b/tests/ti_deps/deps/test_task_concurrency.py
@@ -0,0 +1,51 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from datetime import datetime
+from mock import Mock
+
+from airflow.models import DAG, BaseOperator
+from airflow.ti_deps.dep_context import DepContext
+from airflow.ti_deps.deps.task_concurrency_dep import TaskConcurrencyDep
+from airflow.utils.state import State
+
+
+class TaskConcurrencyDepTest(unittest.TestCase):
+
+    def _get_task(self, **kwargs):
+        return BaseOperator(task_id='test_task', dag=DAG('test_dag'), **kwargs)
+
+    def test_not_task_concurrency(self):
+        task = self._get_task(start_date=datetime(2016, 1, 1))
+        dep_context = DepContext()
+        ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
+        self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context))
+
+    def test_not_reached_concurrency(self):
+        task = self._get_task(start_date=datetime(2016, 1, 1), task_concurrency=1)
+        dep_context = DepContext()
+        ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
+        ti.get_num_running_task_instances = lambda x: 0
+        self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context))
+
+    def test_reached_concurrency(self):
+        task = self._get_task(start_date=datetime(2016, 1, 1), task_concurrency=2)
+        dep_context = DepContext()
+        ti = Mock(task=task, execution_date=datetime(2016, 1, 1))
+        ti.get_num_running_task_instances = lambda x: 1
+        self.assertTrue(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context))
+        ti.get_num_running_task_instances = lambda x: 2
+        self.assertFalse(TaskConcurrencyDep().is_met(ti=ti, dep_context=dep_context))
+