You are viewing a plain-text version of this content. The link to the canonical version in the mailing-list archive was not preserved in this copy.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/02/20 11:07:39 UTC

[airflow] branch main updated: Refactor TriggerRule & WeightRule classes to inherit from Enum (#21264)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9ad4de8  Refactor TriggerRule & WeightRule classes to inherit from Enum (#21264)
9ad4de8 is described below

commit 9ad4de835cbbd296b9dbd1ff0ea88c1cd0050263
Author: Chenglong Yan <al...@gmail.com>
AuthorDate: Sun Feb 20 19:06:33 2022 +0800

    Refactor TriggerRule & WeightRule classes to inherit from Enum (#21264)
    
    closes: #19905
    related: #5302,#18627
    
    Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
    
    Co-authored-by: Tzu-ping Chung <ur...@gmail.com>
---
 airflow/utils/trigger_rule.py    | 21 ++++++++-------------
 airflow/utils/weight_rule.py     | 22 ++++++++++------------
 tests/utils/test_trigger_rule.py |  5 +++++
 tests/utils/test_weight_rule.py  |  5 +++++
 4 files changed, 28 insertions(+), 25 deletions(-)

diff --git a/airflow/utils/trigger_rule.py b/airflow/utils/trigger_rule.py
index 890bdc7..44fc8b51 100644
--- a/airflow/utils/trigger_rule.py
+++ b/airflow/utils/trigger_rule.py
@@ -15,11 +15,11 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+from enum import Enum
 from typing import Set
 
 
-class TriggerRule:
+class TriggerRule(str, Enum):
     """Class with task's trigger rules."""
 
     ALL_SUCCESS = 'all_success'
@@ -34,20 +34,15 @@ class TriggerRule:
     ALWAYS = 'always'
     NONE_FAILED_MIN_ONE_SUCCESS = "none_failed_min_one_success"
 
-    _ALL_TRIGGER_RULES: Set[str] = set()
-
     @classmethod
-    def is_valid(cls, trigger_rule):
+    def is_valid(cls, trigger_rule: str) -> bool:
         """Validates a trigger rule."""
         return trigger_rule in cls.all_triggers()
 
     @classmethod
-    def all_triggers(cls):
+    def all_triggers(cls) -> Set[str]:
         """Returns all trigger rules."""
-        if not cls._ALL_TRIGGER_RULES:
-            cls._ALL_TRIGGER_RULES = {
-                getattr(cls, attr)
-                for attr in dir(cls)
-                if not attr.startswith("_") and not callable(getattr(cls, attr))
-            }
-        return cls._ALL_TRIGGER_RULES
+        return set(cls.__members__.values())
+
+    def __str__(self) -> str:
+        return self.value
diff --git a/airflow/utils/weight_rule.py b/airflow/utils/weight_rule.py
index 002229c..f4f9cc3 100644
--- a/airflow/utils/weight_rule.py
+++ b/airflow/utils/weight_rule.py
@@ -15,31 +15,29 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+from enum import Enum
 from typing import Set
 
+from airflow.compat.functools import cache
+
 
-class WeightRule:
+class WeightRule(str, Enum):
     """Weight rules."""
 
     DOWNSTREAM = 'downstream'
     UPSTREAM = 'upstream'
     ABSOLUTE = 'absolute'
 
-    _ALL_WEIGHT_RULES: Set[str] = set()
-
     @classmethod
-    def is_valid(cls, weight_rule):
+    def is_valid(cls, weight_rule: str) -> bool:
         """Check if weight rule is valid."""
         return weight_rule in cls.all_weight_rules()
 
     @classmethod
+    @cache
     def all_weight_rules(cls) -> Set[str]:
         """Returns all weight rules"""
-        if not cls._ALL_WEIGHT_RULES:
-            cls._ALL_WEIGHT_RULES = {
-                getattr(cls, attr)
-                for attr in dir(cls)
-                if not attr.startswith("_") and not callable(getattr(cls, attr))
-            }
-        return cls._ALL_WEIGHT_RULES
+        return set(cls.__members__.values())
+
+    def __str__(self) -> str:
+        return self.value
diff --git a/tests/utils/test_trigger_rule.py b/tests/utils/test_trigger_rule.py
index 5132c34..05ff9da 100644
--- a/tests/utils/test_trigger_rule.py
+++ b/tests/utils/test_trigger_rule.py
@@ -18,6 +18,8 @@
 
 import unittest
 
+import pytest
+
 from airflow.utils.trigger_rule import TriggerRule
 
 
@@ -35,3 +37,6 @@ class TestTriggerRule(unittest.TestCase):
         assert TriggerRule.is_valid(TriggerRule.ALWAYS)
         assert TriggerRule.is_valid(TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS)
         assert len(TriggerRule.all_triggers()) == 11
+
+        with pytest.raises(ValueError):
+            TriggerRule("NOT_EXIST_TRIGGER_RULE")
diff --git a/tests/utils/test_weight_rule.py b/tests/utils/test_weight_rule.py
index cad142d..440b213 100644
--- a/tests/utils/test_weight_rule.py
+++ b/tests/utils/test_weight_rule.py
@@ -18,6 +18,8 @@
 
 import unittest
 
+import pytest
+
 from airflow.utils.weight_rule import WeightRule
 
 
@@ -27,3 +29,6 @@ class TestWeightRule(unittest.TestCase):
         assert WeightRule.is_valid(WeightRule.UPSTREAM)
         assert WeightRule.is_valid(WeightRule.ABSOLUTE)
         assert len(WeightRule.all_weight_rules()) == 3
+
+        with pytest.raises(ValueError):
+            WeightRule("NOT_EXIST_WEIGHT_RULE")