You are viewing a plain-text version of this content. The canonical link for it (a hyperlink in the original archive page) is not preserved in this plain-text rendering.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/02/20 11:07:39 UTC
[airflow] branch main updated: Refactor TriggerRule & WeightRule classes to inherit from Enum (#21264)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 9ad4de8 Refactor TriggerRule & WeightRule classes to inherit from Enum (#21264)
9ad4de8 is described below
commit 9ad4de835cbbd296b9dbd1ff0ea88c1cd0050263
Author: Chenglong Yan <al...@gmail.com>
AuthorDate: Sun Feb 20 19:06:33 2022 +0800
Refactor TriggerRule & WeightRule classes to inherit from Enum (#21264)
closes: #19905
related: #5302,#18627
Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
---
airflow/utils/trigger_rule.py | 21 ++++++++-------------
airflow/utils/weight_rule.py | 22 ++++++++++------------
tests/utils/test_trigger_rule.py | 5 +++++
tests/utils/test_weight_rule.py | 5 +++++
4 files changed, 28 insertions(+), 25 deletions(-)
diff --git a/airflow/utils/trigger_rule.py b/airflow/utils/trigger_rule.py
index 890bdc7..44fc8b51 100644
--- a/airflow/utils/trigger_rule.py
+++ b/airflow/utils/trigger_rule.py
@@ -15,11 +15,11 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+from enum import Enum
from typing import Set
-class TriggerRule:
+class TriggerRule(str, Enum):
"""Class with task's trigger rules."""
ALL_SUCCESS = 'all_success'
@@ -34,20 +34,15 @@ class TriggerRule:
ALWAYS = 'always'
NONE_FAILED_MIN_ONE_SUCCESS = "none_failed_min_one_success"
- _ALL_TRIGGER_RULES: Set[str] = set()
-
@classmethod
- def is_valid(cls, trigger_rule):
+ def is_valid(cls, trigger_rule: str) -> bool:
"""Validates a trigger rule."""
return trigger_rule in cls.all_triggers()
@classmethod
- def all_triggers(cls):
+ def all_triggers(cls) -> Set[str]:
"""Returns all trigger rules."""
- if not cls._ALL_TRIGGER_RULES:
- cls._ALL_TRIGGER_RULES = {
- getattr(cls, attr)
- for attr in dir(cls)
- if not attr.startswith("_") and not callable(getattr(cls, attr))
- }
- return cls._ALL_TRIGGER_RULES
+ return set(cls.__members__.values())
+
+ def __str__(self) -> str:
+ return self.value
diff --git a/airflow/utils/weight_rule.py b/airflow/utils/weight_rule.py
index 002229c..f4f9cc3 100644
--- a/airflow/utils/weight_rule.py
+++ b/airflow/utils/weight_rule.py
@@ -15,31 +15,29 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
+from enum import Enum
from typing import Set
+from airflow.compat.functools import cache
+
-class WeightRule:
+class WeightRule(str, Enum):
"""Weight rules."""
DOWNSTREAM = 'downstream'
UPSTREAM = 'upstream'
ABSOLUTE = 'absolute'
- _ALL_WEIGHT_RULES: Set[str] = set()
-
@classmethod
- def is_valid(cls, weight_rule):
+ def is_valid(cls, weight_rule: str) -> bool:
"""Check if weight rule is valid."""
return weight_rule in cls.all_weight_rules()
@classmethod
+ @cache
def all_weight_rules(cls) -> Set[str]:
"""Returns all weight rules"""
- if not cls._ALL_WEIGHT_RULES:
- cls._ALL_WEIGHT_RULES = {
- getattr(cls, attr)
- for attr in dir(cls)
- if not attr.startswith("_") and not callable(getattr(cls, attr))
- }
- return cls._ALL_WEIGHT_RULES
+ return set(cls.__members__.values())
+
+ def __str__(self) -> str:
+ return self.value
diff --git a/tests/utils/test_trigger_rule.py b/tests/utils/test_trigger_rule.py
index 5132c34..05ff9da 100644
--- a/tests/utils/test_trigger_rule.py
+++ b/tests/utils/test_trigger_rule.py
@@ -18,6 +18,8 @@
import unittest
+import pytest
+
from airflow.utils.trigger_rule import TriggerRule
@@ -35,3 +37,6 @@ class TestTriggerRule(unittest.TestCase):
assert TriggerRule.is_valid(TriggerRule.ALWAYS)
assert TriggerRule.is_valid(TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
assert len(TriggerRule.all_triggers()) == 11
+
+ with pytest.raises(ValueError):
+ TriggerRule("NOT_EXIST_TRIGGER_RULE")
diff --git a/tests/utils/test_weight_rule.py b/tests/utils/test_weight_rule.py
index cad142d..440b213 100644
--- a/tests/utils/test_weight_rule.py
+++ b/tests/utils/test_weight_rule.py
@@ -18,6 +18,8 @@
import unittest
+import pytest
+
from airflow.utils.weight_rule import WeightRule
@@ -27,3 +29,6 @@ class TestWeightRule(unittest.TestCase):
assert WeightRule.is_valid(WeightRule.UPSTREAM)
assert WeightRule.is_valid(WeightRule.ABSOLUTE)
assert len(WeightRule.all_weight_rules()) == 3
+
+ with pytest.raises(ValueError):
+ WeightRule("NOT_EXIST_WEIGHT_RULE")