You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by bo...@apache.org on 2018/01/18 15:10:28 UTC

incubator-airflow git commit: [AIRFLOW-192] Add weight_rule param to BaseOperator

Repository: incubator-airflow
Updated Branches:
  refs/heads/master fbba5ef7c -> dd2bc8cb9


[AIRFLOW-192] Add weight_rule param to BaseOperator

Improved task generation performance significantly by using sets of
task_ids and dag_ids instead of lists when calculating the total priority
weight.

Closes #2941 from wongwill86/performance-latest


Project: http://git-wip-us.apache.org/repos/asf/incubator-airflow/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-airflow/commit/dd2bc8cb
Tree: http://git-wip-us.apache.org/repos/asf/incubator-airflow/tree/dd2bc8cb
Diff: http://git-wip-us.apache.org/repos/asf/incubator-airflow/diff/dd2bc8cb

Branch: refs/heads/master
Commit: dd2bc8cb971d25087a35db16d12592f759ecbc6a
Parents: fbba5ef
Author: wongwill86 <wo...@gmail.com>
Authored: Thu Jan 18 16:09:40 2018 +0100
Committer: Bolke de Bruin <bo...@xs4all.nl>
Committed: Thu Jan 18 16:09:46 2018 +0100

----------------------------------------------------------------------
 airflow/models.py            | 102 +++++++++++++++++++++++++++++++-------
 airflow/utils/weight_rule.py |  33 ++++++++++++
 tests/models.py              |  93 ++++++++++++++++++++++++++++++++++
 3 files changed, 210 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dd2bc8cb/airflow/models.py
----------------------------------------------------------------------
diff --git a/airflow/models.py b/airflow/models.py
index 08c4b52..c5233ec 100755
--- a/airflow/models.py
+++ b/airflow/models.py
@@ -19,7 +19,6 @@ from __future__ import unicode_literals
 
 from future.standard_library import install_aliases
 
-install_aliases()
 from builtins import str
 from builtins import object, bytes
 import copy
@@ -84,8 +83,11 @@ from airflow.utils.operator_resources import Resources
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
 from airflow.utils.trigger_rule import TriggerRule
+from airflow.utils.weight_rule import WeightRule
 from airflow.utils.log.logging_mixin import LoggingMixin
 
+install_aliases()
+
 Base = declarative_base()
 ID_LEN = 250
 XCOM_RETURN_KEY = 'return_value'
@@ -2073,6 +2075,29 @@ class BaseOperator(LoggingMixin):
         This allows the executor to trigger higher priority tasks before
         others when things get backed up.
     :type priority_weight: int
+    :param weight_rule: weighting method used for the effective total
+        priority weight of the task. Options are:
+        ``{ downstream | upstream | absolute }``; default is ``downstream``.
+        When set to ``downstream`` the effective weight of the task is the
+        aggregate sum of all downstream descendants. As a result, upstream
+        tasks will have higher weight and will be scheduled more aggressively
+        when using positive weight values. This is useful when you have
+        multiple dag run instances and desire to have all upstream tasks to
+        complete for all runs before each dag can continue processing
+        downstream tasks. When set to ``upstream`` the effective weight is the
+        aggregate sum of all upstream ancestors. This is the opposite where
+        downstream tasks have higher weight and will be scheduled more
+        aggressively when using positive weight values. This is useful when you
+        have multiple dag run instances and prefer to have each dag complete
+        before starting upstream tasks of other dags.  When set to
+        ``absolute``, the effective weight is the exact ``priority_weight``
+        specified without additional weighting. You may want to do this when
+        you know exactly what priority weight each task should have.
+        Additionally, when set to ``absolute``, there is the bonus effect of
+        significantly speeding up the task creation process for very large
+        DAGs. Options can be set as a string or using the constants defined in
+        the static class ``airflow.utils.weight_rule.WeightRule``
+    :type weight_rule: str
     :param pool: the slot pool this task should run in, slot pools are a
         way to limit concurrency for certain tasks
     :type pool: str
@@ -2150,6 +2175,7 @@ class BaseOperator(LoggingMixin):
             default_args=None,
             adhoc=False,
             priority_weight=1,
+            weight_rule=WeightRule.DOWNSTREAM,
             queue=configuration.get('celery', 'default_queue'),
             pool=None,
             sla=None,
@@ -2190,7 +2216,7 @@ class BaseOperator(LoggingMixin):
                 "The trigger_rule must be one of {all_triggers},"
                 "'{d}.{t}'; received '{tr}'."
                 .format(all_triggers=TriggerRule.all_triggers,
-                        d=dag.dag_id, t=task_id, tr=trigger_rule))
+                        d=dag.dag_id if dag else "", t=task_id, tr=trigger_rule))
 
         self.trigger_rule = trigger_rule
         self.depends_on_past = depends_on_past
@@ -2224,6 +2250,14 @@ class BaseOperator(LoggingMixin):
         self.params = params or {}  # Available in templates!
         self.adhoc = adhoc
         self.priority_weight = priority_weight
+        if not WeightRule.is_valid(weight_rule):
+            raise AirflowException(
+                "The weight_rule must be one of {all_weight_rules}, "
+                "'{d}.{t}'; received '{tr}'."
+                .format(all_weight_rules=WeightRule.all_weight_rules(),
+                        d=dag.dag_id if dag else "", t=task_id, tr=weight_rule))
+        self.weight_rule = weight_rule
+
         self.resources = Resources(**(resources or {}))
         self.run_as_user = run_as_user
         self.task_concurrency = task_concurrency
@@ -2402,10 +2436,19 @@ class BaseOperator(LoggingMixin):
 
     @property
     def priority_weight_total(self):
-        return sum([
-            t.priority_weight
-            for t in self.get_flat_relatives(upstream=False)
-        ]) + self.priority_weight
+        """
+        Effective priority weight of the task, according to ``weight_rule``.
+        """
+        if self.weight_rule == WeightRule.ABSOLUTE:
+            return self.priority_weight
+
+        # Other rules add the weights of related tasks: ancestors for
+        # ``upstream``, descendants for ``downstream`` (also the fallback).
+        upstream = self.weight_rule == WeightRule.UPSTREAM
+        return self.priority_weight + sum(
+            self._dag.task_dict[task_id].priority_weight
+            for task_id in self.get_flat_relative_ids(upstream=upstream)
+        )
 
     def pre_execute(self, context):
         """
@@ -2608,17 +2651,30 @@ class BaseOperator(LoggingMixin):
             TI.execution_date <= end_date,
         ).order_by(TI.execution_date).all()
 
-    def get_flat_relatives(self, upstream=False, l=None):
+    def get_flat_relative_ids(self, upstream=False, found_descendants=None):
+        """
+        Get a flat set of relatives' ids, either upstream or downstream.
+
+        Walks the graph iteratively so long task chains cannot hit the
+        recursion limit. Ids are added to ``found_descendants`` if given.
+        """
+        if found_descendants is None:
+            found_descendants = set()
+        tasks_to_visit = [self]
+        while tasks_to_visit:
+            task = tasks_to_visit.pop()
+            for relative_id in task.get_direct_relative_ids(upstream):
+                if relative_id not in found_descendants:
+                    found_descendants.add(relative_id)
+                    tasks_to_visit.append(self._dag.task_dict[relative_id])
+        return found_descendants
+
+    def get_flat_relatives(self, upstream=False):
         """
         Get a flat list of relatives, either upstream or downstream.
         """
-        if not l:
-            l = []
-        for t in self.get_direct_relatives(upstream):
-            if not is_in(t, l):
-                l.append(t)
-                t.get_flat_relatives(upstream, l)
-        return l
+        return [self._dag.task_dict[task_id]
+                for task_id in self.get_flat_relative_ids(upstream)]
 
     def detect_downstream_cycle(self, task=None):
         """
@@ -2664,6 +2720,16 @@ class BaseOperator(LoggingMixin):
                 self.log.info('Rendering template for %s', attr)
                 self.log.info(content)
 
+    def get_direct_relative_ids(self, upstream=False):
+        """
+        Get the ids of the tasks directly related to the current task: its
+        upstream tasks when ``upstream`` is true, else its downstream tasks.
+        """
+        # The task's own id collections are returned (no copy is made), so
+        # callers should treat the result as read-only.
+        return (self._upstream_task_ids if upstream
+                else self._downstream_task_ids)
+
     def get_direct_relatives(self, upstream=False):
         """
         Get the direct relatives to the current task, upstream or
@@ -2704,14 +2770,14 @@ class BaseOperator(LoggingMixin):
 
         # relationships can only be set if the tasks share a single DAG. Tasks
         # without a DAG are assigned to that DAG.
-        dags = set(t.dag for t in [self] + task_list if t.has_dag())
+        dags = {t._dag.dag_id: t.dag for t in [self] + task_list if t.has_dag()}
 
         if len(dags) > 1:
             raise AirflowException(
                 'Tried to set relationships between tasks in '
-                'more than one DAG: {}'.format(dags))
+                'more than one DAG: {}'.format(dags.values()))
         elif len(dags) == 1:
-            dag = list(dags)[0]
+            dag = dags.popitem()[1]
         else:
             raise AirflowException(
                 "Tried to create relationships between tasks that don't have "
@@ -4739,7 +4805,7 @@ class DagRun(Base, LoggingMixin):
                     ti.state = State.REMOVED
 
         # check for missing tasks
-        for task in dag.tasks:
+        for task in six.itervalues(dag.task_dict):
             if task.adhoc:
                 continue
 

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dd2bc8cb/airflow/utils/weight_rule.py
----------------------------------------------------------------------
diff --git a/airflow/utils/weight_rule.py b/airflow/utils/weight_rule.py
new file mode 100644
index 0000000..fde0d90
--- /dev/null
+++ b/airflow/utils/weight_rule.py
@@ -0,0 +1,33 @@
+# -*- coding: utf-8 -*-
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+from __future__ import unicode_literals
+
+from builtins import object
+
+
+class WeightRule(object):
+    """Weighting methods for a task's effective total priority weight."""
+    DOWNSTREAM = 'downstream'
+    UPSTREAM = 'upstream'
+    ABSOLUTE = 'absolute'
+
+    @classmethod
+    def is_valid(cls, weight_rule):
+        return weight_rule in cls.all_weight_rules()
+
+    @classmethod
+    def all_weight_rules(cls):
+        return [getattr(cls, attr) for attr in dir(cls)
+                if not attr.startswith("_") and not callable(getattr(cls, attr))]

http://git-wip-us.apache.org/repos/asf/incubator-airflow/blob/dd2bc8cb/tests/models.py
----------------------------------------------------------------------
diff --git a/tests/models.py b/tests/models.py
index 3bab3cf..f0879eb 100644
--- a/tests/models.py
+++ b/tests/models.py
@@ -23,6 +23,8 @@ import os
 import pendulum
 import unittest
 import time
+import six
+import re
 
 from airflow import configuration, models, settings, AirflowException
 from airflow.exceptions import AirflowSkipException
@@ -39,6 +41,7 @@ from airflow.operators.python_operator import PythonOperator
 from airflow.operators.python_operator import ShortCircuitOperator
 from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep
 from airflow.utils import timezone
+from airflow.utils.weight_rule import WeightRule
 from airflow.utils.state import State
 from airflow.utils.trigger_rule import TriggerRule
 from mock import patch
@@ -201,6 +204,96 @@ class DagTest(unittest.TestCase):
 
         self.assertEquals(tuple(), dag.topological_sort())
 
+    def test_dag_task_priority_weight_total(self):
+        width = 5
+        depth = 5
+        weight = 5
+        pattern = re.compile(r'stage(\d+)\.(\d+)')
+        # Fully connected parallel tasks. i.e. every task at each parallel
+        # stage is dependent on every task in the previous stage.
+        # Default weight should be calculated using downstream descendants
+        with DAG('dag', start_date=DEFAULT_DATE,
+                 default_args={'owner': 'owner1'}) as dag:
+            pipeline = [
+                [DummyOperator(
+                    task_id='stage{}.{}'.format(i, j), priority_weight=weight)
+                    for j in range(0, width)] for i in range(0, depth)
+            ]
+            for d, stage in enumerate(pipeline):
+                if d == 0:
+                    continue
+                for current_task in stage:
+                    for prev_task in pipeline[d - 1]:
+                        current_task.set_upstream(prev_task)
+
+            for task in six.itervalues(dag.task_dict):
+                match = pattern.match(task.task_id)
+                task_depth = int(match.group(1))
+                # the weights of every task in later stages + its own weight
+                correct_weight = ((depth - (task_depth + 1)) * width + 1) * weight
+
+                calculated_weight = task.priority_weight_total
+                self.assertEqual(calculated_weight, correct_weight)
+
+        # Same test as above except use 'upstream' for weight calculation
+        weight = 3
+        with DAG('dag', start_date=DEFAULT_DATE,
+                 default_args={'owner': 'owner1'}) as dag:
+            pipeline = [
+                [DummyOperator(
+                    task_id='stage{}.{}'.format(i, j), priority_weight=weight,
+                    weight_rule=WeightRule.UPSTREAM)
+                    for j in range(0, width)] for i in range(0, depth)
+            ]
+            for d, stage in enumerate(pipeline):
+                if d == 0:
+                    continue
+                for current_task in stage:
+                    for prev_task in pipeline[d - 1]:
+                        current_task.set_upstream(prev_task)
+
+            for task in six.itervalues(dag.task_dict):
+                match = pattern.match(task.task_id)
+                task_depth = int(match.group(1))
+                # the weights of every task in earlier stages + its own weight
+                correct_weight = ((task_depth) * width + 1) * weight
+
+                calculated_weight = task.priority_weight_total
+                self.assertEqual(calculated_weight, correct_weight)
+
+        # Same test as above except use 'absolute' for weight calculation
+        weight = 10
+        with DAG('dag', start_date=DEFAULT_DATE,
+                 default_args={'owner': 'owner1'}) as dag:
+            pipeline = [
+                [DummyOperator(
+                    task_id='stage{}.{}'.format(i, j), priority_weight=weight,
+                    weight_rule=WeightRule.ABSOLUTE)
+                    for j in range(0, width)] for i in range(0, depth)
+            ]
+            for d, stage in enumerate(pipeline):
+                if d == 0:
+                    continue
+                for current_task in stage:
+                    for prev_task in pipeline[d - 1]:
+                        current_task.set_upstream(prev_task)
+
+            for task in six.itervalues(dag.task_dict):
+                # 'absolute' ignores relatives entirely, so a task in any
+                # stage keeps exactly the weight it was given
+                correct_weight = weight
+
+                calculated_weight = task.priority_weight_total
+                self.assertEqual(calculated_weight, correct_weight)
+
+        # Test if we enter an invalid weight rule
+        with DAG('dag', start_date=DEFAULT_DATE,
+                 default_args={'owner': 'owner1'}) as dag:
+            with self.assertRaises(AirflowException):
+                DummyOperator(task_id='should_fail', weight_rule='no rule')
+            # plain strings are accepted as well as the WeightRule constants
+            DummyOperator(task_id='string_rule', weight_rule='absolute')
+
     def test_get_num_task_instances(self):
         test_dag_id = 'test_get_num_task_instances_dag'
         test_task_id = 'task_1'