You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by bo...@apache.org on 2018/01/18 15:10:28 UTC
incubator-airflow git commit: [AIRFLOW-192] Add weight_rule param to
BaseOperator
Repository: incubator-airflow
Updated Branches:
refs/heads/master fbba5ef7c -> dd2bc8cb9
[AIRFLOW-192] Add weight_rule param to BaseOperator
Improved task generation performance significantly
by using sets of
task_ids and dag_ids instead of lists when
calculating total priority
weight.
Closes #2941 from wongwill86/performance-latest
Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/dd2bc8cb
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/dd2bc8cb
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/dd2bc8cb
Branch: refs/heads/master
Commit: dd2bc8cb971d25087a35db16d12592f759ecbc6a
Parents: fbba5ef
Author: wongwill86 <wo...@gmail.com>
Authored: Thu Jan 18 16:09:40 2018 +0100
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Thu Jan 18 16:09:46 2018 +0100
----------------------------------------------------------------------
airflow/models.py | 102 +++++++++++++++++++++++++++++++-------
airflow/utils/weight_rule.py | 33 ++++++++++++
tests/models.py | 93 ++++++++++++++++++++++++++++++++++
3 files changed, 210 insertions(+), 18 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dd2bc8cb/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index 08c4b52..c5233ec 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -19,7 +19,6 @@ from __future__ import unicode_literals
from future.standard_library import install_aliases
-install_aliases()
from builtins import str
from builtins import object, bytes
import copy
@@ -84,8 +83,11 @@ from airflow.utils.operator_resources import Resources
from airflow.utils.state import State
from airflow.utils.timeout import timeout
from airflow.utils.trigger_rule import TriggerRule
+from airflow.utils.weight_rule import WeightRule
from airflow.utils.log.logging_mixin import LoggingMixin
+install_aliases()
+
Base = declarative_base()
ID_LEN = 250
XCOM_RETURN_KEY = 'return_value'
@@ -2073,6 +2075,29 @@ class BaseOperator(LoggingMixin):
This allows the executor to trigger higher priority tasks before
others when things get backed up.
:type priority_weight: int
+ :param weight_rule: weighting method used for the effective total
+ priority weight of the task. Options are:
+ ``{ downstream | upstream | absolute }`` default is ``downstream``
+ When set to ``downstream`` the effective weight of the task is the
+ aggregate sum of all downstream descendants. As a result, upstream
+ tasks will have higher weight and will be scheduled more aggressively
+ when using positive weight values. This is useful when you have
+ multiple dag run instances and desire to have all upstream tasks
+ complete for all runs before each dag can continue processing
+ downstream tasks. When set to ``upstream`` the effective weight is the
+ aggregate sum of all upstream ancestors. This is the opposite where
+ downstream tasks have higher weight and will be scheduled more
+ aggressively when using positive weight values. This is useful when you
+ have multiple dag run instances and prefer to have each dag complete
+ before starting upstream tasks of other dags. When set to
+ ``absolute``, the effective weight is the exact ``priority_weight``
+ specified without additional weighting. You may want to do this when
+ you know exactly what priority weight each task should have.
+ Additionally, when set to ``absolute``, there is a bonus effect of
+ significantly speeding up the task creation process for very large
+ DAGs. Options can be set as a string or using the constants defined in
+ the static class ``airflow.utils.weight_rule.WeightRule``
+ :type weight_rule: str
:param pool: the slot pool this task should run in, slot pools are a
way to limit concurrency for certain tasks
:type pool: str
@@ -2150,6 +2175,7 @@ class BaseOperator(LoggingMixin):
default_args=None,
adhoc=False,
priority_weight=1,
+ weight_rule=WeightRule.DOWNSTREAM,
queue=configuration.get('celery', 'default_queue'),
pool=None,
sla=None,
@@ -2190,7 +2216,7 @@ class BaseOperator(LoggingMixin):
"The trigger_rule must be one of {all_triggers},"
"'{d}.{t}'; received '{tr}'."
.format(all_triggers=TriggerRule.all_triggers,
- d=dag.dag_id, t=task_id, tr=trigger_rule))
+ d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule))
self.trigger_rule = trigger_rule
self.depends_on_past = depends_on_past
@@ -2224,6 +2250,14 @@ class BaseOperator(LoggingMixin):
self.params = params or {} # Available in templates!
self.adhoc = adhoc
self.priority_weight = priority_weight
+ if not WeightRule.is_valid(weight_rule):
+ raise AirflowException(
+ "The weight_rule must be one of {all_weight_rules},"
+ "'{d}.{t}'; received '{tr}'."
+ .format(all_weight_rules=WeightRule.all_weight_rules,
+ d=dag.dag_id if dag else "", t=task_id, tr=weight_rule))
+ self.weight_rule = weight_rule
+
self.resources = Resources(**(resources or {}))
self.run_as_user = run_as_user
self.task_concurrency = task_concurrency
@@ -2402,10 +2436,19 @@ class BaseOperator(LoggingMixin):
@property
def priority_weight_total(self):
- return sum([
- t.priority_weight
- for t in self.get_flat_relatives(upstream=False)
- ]) + self.priority_weight
+ if self.weight_rule == WeightRule.ABSOLUTE:
+ return self.priority_weight
+ elif self.weight_rule == WeightRule.DOWNSTREAM:
+ upstream = False
+ elif self.weight_rule == WeightRule.UPSTREAM:
+ upstream = True
+ else:
+ upstream = False
+
+ return self.priority_weight + sum(
+ map(lambda task_id: self._dag.task_dict[task_id].priority_weight,
+ self.get_flat_relative_ids(upstream=upstream))
+ )
def pre_execute(self, context):
"""
@@ -2608,17 +2651,30 @@ class BaseOperator(LoggingMixin):
TI.execution_date <= end_date,
).order_by(TI.execution_date).all()
- def get_flat_relatives(self, upstream=False, l=None):
+ def get_flat_relative_ids(self, upstream=False, found_descendants=None):
+ """
+ Get a flat list of relatives' ids, either upstream or downstream.
+ """
+
+ if not found_descendants:
+ found_descendants = set()
+ relative_ids = self.get_direct_relative_ids(upstream)
+
+ for relative_id in relative_ids:
+ if relative_id not in found_descendants:
+ found_descendants.add(relative_id)
+ relative_task = self._dag.task_dict[relative_id]
+ relative_task.get_flat_relative_ids(upstream,
+ found_descendants)
+
+ return found_descendants
+
+ def get_flat_relatives(self, upstream=False):
"""
Get a flat list of relatives, either upstream or downstream.
"""
- if not l:
- l = []
- for t in self.get_direct_relatives(upstream):
- if not is_in(t, l):
- l.append(t)
- t.get_flat_relatives(upstream, l)
- return l
+ return list(map(lambda task_id: self._dag.task_dict[task_id],
+ self.get_flat_relative_ids(upstream)))
def detect_downstream_cycle(self, task=None):
"""
@@ -2664,6 +2720,16 @@ class BaseOperator(LoggingMixin):
self.log.info('Rendering template for %s', attr)
self.log.info(content)
+ def get_direct_relative_ids(self, upstream=False):
+ """
+ Get the direct relative ids to the current task, upstream or
+ downstream.
+ """
+ if upstream:
+ return self._upstream_task_ids
+ else:
+ return self._downstream_task_ids
+
def get_direct_relatives(self, upstream=False):
"""
Get the direct relatives to the current task, upstream or
@@ -2704,14 +2770,14 @@ class BaseOperator(LoggingMixin):
# relationships can only be set if the tasks share a single DAG. Tasks
# without a DAG are assigned to that DAG.
- dags = set(t.dag for t in [self] + task_list if t.has_dag())
+ dags = {t._dag.dag_id: t.dag for t in [self] + task_list if t.has_dag()}
if len(dags) > 1:
raise AirflowException(
'Tried to set relationships between tasks in '
- 'more than one DAG: {}'.format(dags))
+ 'more than one DAG: {}'.format(dags.values()))
elif len(dags) == 1:
- dag = list(dags)[0]
+ dag = dags.popitem()[1]
else:
raise AirflowException(
"Tried to create relationships between tasks that don't have "
@@ -4739,7 +4805,7 @@ class DagRun(Base, LoggingMixin):
ti.state = State.REMOVED
# check for missing tasks
- for task in dag.tasks:
+ for task in six.itervalues(dag.task_dict):
if task.adhoc:
continue
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dd2bc8cb/airflow/utils/weight_rule.py
----------------------------------------------------------------------
diff --git a/airflow/utils/weight_rule.py b/airflow/utils/weight_rule.py
new file mode 100644
index 0000000..fde0d90
--- /dev/null
+++ b/airflow/utils/weight_rule.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import unicode_literals
+
+from builtins import object
+
+
+class WeightRule(object):
+ DOWNSTREAM = 'downstream'
+ UPSTREAM = 'upstream'
+ ABSOLUTE = 'absolute'
+
+ @classmethod
+ def is_valid(cls, weight_rule):
+ return weight_rule in cls.all_weight_rules()
+
+ @classmethod
+ def all_weight_rules(cls):
+ return [getattr(cls, attr)
+ for attr in dir(cls)
+ if not attr.startswith("__") and not callable(getattr(cls, attr))]
http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dd2bc8cb/tests/models.py
----------------------------------------------------------------------
diff --git a/tests/models.py b/tests/models.py
index 3bab3cf..f0879eb 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -23,6 +23,8 @@ import os
import pendulum
import unittest
import time
+import six
+import re
from airflow import configuration, models, settings, AirflowException
from airflow.exceptions import AirflowSkipException
@@ -39,6 +41,7 @@ from airflow.operators.python_operator import PythonOperator
from airflow.operators.python_operator import ShortCircuitOperator
from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
from airflow.utils import timezone
+from airflow.utils.weight_rule import WeightRule
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule
from mock import patch
@@ -201,6 +204,96 @@ class DagTest(unittest.TestCase):
self.assertEquals(tuple(), dag.topological_sort())
+ def test_dag_task_priority_weight_total(self):
+ width = 5
+ depth = 5
+ weight = 5
+ pattern = re.compile('stage(\\d*).(\\d*)')
+ # Fully connected parallel tasks. i.e. every task at each parallel
+ # stage is dependent on every task in the previous stage.
+ # Default weight should be calculated using downstream descendants
+ with DAG('dag', start_date=DEFAULT_DATE,
+ default_args={'owner': 'owner1'}) as dag:
+ pipeline = [
+ [DummyOperator(
+ task_id='stage{}.{}'.format(i, j), priority_weight=weight)
+ for j in range(0, width)] for i in range(0, depth)
+ ]
+ for d, stage in enumerate(pipeline):
+ if d == 0:
+ continue
+ for current_task in stage:
+ for prev_task in pipeline[d - 1]:
+ current_task.set_upstream(prev_task)
+
+ for task in six.itervalues(dag.task_dict):
+ match = pattern.match(task.task_id)
+ task_depth = int(match.group(1))
+ # the sum of each stage after this task + itself
+ correct_weight = ((depth - (task_depth + 1)) * width + 1) * weight
+
+ calculated_weight = task.priority_weight_total
+ self.assertEquals(calculated_weight, correct_weight)
+
+ # Same test as above except use 'upstream' for weight calculation
+ weight = 3
+ with DAG('dag', start_date=DEFAULT_DATE,
+ default_args={'owner': 'owner1'}) as dag:
+ pipeline = [
+ [DummyOperator(
+ task_id='stage{}.{}'.format(i, j), priority_weight=weight,
+ weight_rule=WeightRule.UPSTREAM)
+ for j in range(0, width)] for i in range(0, depth)
+ ]
+ for d, stage in enumerate(pipeline):
+ if d == 0:
+ continue
+ for current_task in stage:
+ for prev_task in pipeline[d - 1]:
+ current_task.set_upstream(prev_task)
+
+ for task in six.itervalues(dag.task_dict):
+ match = pattern.match(task.task_id)
+ task_depth = int(match.group(1))
+ # the sum of each stage before this task + itself
+ correct_weight = ((task_depth) * width + 1) * weight
+
+ calculated_weight = task.priority_weight_total
+ self.assertEquals(calculated_weight, correct_weight)
+
+ # Same test as above except use 'absolute' for weight calculation
+ weight = 10
+ with DAG('dag', start_date=DEFAULT_DATE,
+ default_args={'owner': 'owner1'}) as dag:
+ pipeline = [
+ [DummyOperator(
+ task_id='stage{}.{}'.format(i, j), priority_weight=weight,
+ weight_rule=WeightRule.ABSOLUTE)
+ for j in range(0, width)] for i in range(0, depth)
+ ]
+ for d, stage in enumerate(pipeline):
+ if d == 0:
+ continue
+ for current_task in stage:
+ for prev_task in pipeline[d - 1]:
+ current_task.set_upstream(prev_task)
+
+ for task in six.itervalues(dag.task_dict):
+ match = pattern.match(task.task_id)
+ task_depth = int(match.group(1))
+ # absolute rule: the task's own priority_weight only
+ correct_weight = weight
+
+ calculated_weight = task.priority_weight_total
+ self.assertEquals(calculated_weight, correct_weight)
+
+ # Test if we enter an invalid weight rule
+ with DAG('dag', start_date=DEFAULT_DATE,
+ default_args={'owner': 'owner1'}) as dag:
+ with self.assertRaises(AirflowException):
+ DummyOperator(task_id='should_fail', weight_rule='no rule')
+
+
def test_get_num_task_instances(self):
test_dag_id = 'test_get_num_task_instances_dag'
test_task_id = 'task_1'