You are viewing a plain text version of this content. The canonical link for it was a hyperlink that is not preserved in this plain-text copy.
Posted to commits@airflow.apache.org by ds...@apache.org on 2021/12/15 06:33:27 UTC

[airflow] branch main updated: Add and use `exactly_one` helper (#20184)

This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new caa05c8  Add and use `exactly_one` helper (#20184)
caa05c8 is described below

commit caa05c87b54dc60ca1897520a2791b7e8682380a
Author: Daniel Standish <15...@users.noreply.github.com>
AuthorDate: Tue Dec 14 22:32:55 2021 -0800

    Add and use `exactly_one` helper (#20184)
    
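    The XOR operator `^` is not very readable, only behaves as a logical XOR on actual True / False values (not on arbitrary truthy / falsy values), and only compares two values. E.g. if you need to test "exactly one of a, b, or c", you cannot write `a ^ b ^ c`: chained XOR is true whenever an odd number of operands are true, so it is also true when all three are. This commit adds an `exactly_one` boolean helper to address these shortcomings.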
---
 airflow/models/xcom.py      | 10 +++++-----
 airflow/utils/helpers.py    | 13 +++++++++++++
 tests/utils/test_helpers.py | 34 +++++++++++++++++++++++++++++++++-
 3 files changed, 51 insertions(+), 6 deletions(-)

diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index 47b46b8..a3134e2 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -30,7 +30,7 @@ from sqlalchemy.orm import Query, Session, reconstructor, relationship
 from airflow.configuration import conf
 from airflow.models.base import COLLATION_ARGS, ID_LEN, Base
 from airflow.utils import timezone
-from airflow.utils.helpers import is_container
+from airflow.utils.helpers import exactly_one, is_container
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.sqlalchemy import UtcDateTime
@@ -134,7 +134,7 @@ class BaseXCom(Base, LoggingMixin):
         run_id: Optional[str] = None,
     ) -> None:
         """:sphinx-autoapi-skip:"""
-        if not (execution_date is None) ^ (run_id is None):
+        if not exactly_one(execution_date is not None, run_id is not None):
             raise ValueError("Exactly one of execution_date or run_id must be passed")
 
         if run_id:
@@ -225,7 +225,7 @@ class BaseXCom(Base, LoggingMixin):
         run_id: Optional[str] = None,
     ) -> Optional[Any]:
         """:sphinx-autoapi-skip:"""
-        if not (execution_date is None) ^ (run_id is None):
+        if not exactly_one(execution_date is not None, run_id is not None):
             raise ValueError("Exactly one of execution_date or run_id must be passed")
 
         if run_id is not None:
@@ -319,7 +319,7 @@ class BaseXCom(Base, LoggingMixin):
         run_id: Optional[str] = None,
     ) -> Query:
         """:sphinx-autoapi-skip:"""
-        if not (execution_date is None) ^ (run_id is None):
+        if not exactly_one(execution_date is not None, run_id is not None):
             raise ValueError("Exactly one of execution_date or run_id must be passed")
 
         filters = []
@@ -420,7 +420,7 @@ class BaseXCom(Base, LoggingMixin):
         if task_id is None:
             raise TypeError("clear() missing required argument: task_id")
 
-        if not (execution_date is None) ^ (run_id is None):
+        if not exactly_one(execution_date is not None, run_id is not None):
             raise ValueError("Exactly one of execution_date or run_id must be passed")
 
         query = session.query(cls).filter(
diff --git a/airflow/utils/helpers.py b/airflow/utils/helpers.py
index 42a0dfa..8d1e53b 100644
--- a/airflow/utils/helpers.py
+++ b/airflow/utils/helpers.py
@@ -286,3 +286,16 @@ def render_template_to_string(template: jinja2.Template, context: Context) -> st
 def render_template_as_native(template: jinja2.Template, context: Context) -> Any:
     """Shorthand to ``render_template(native=True)`` with better typing support."""
     return render_template(template, context, native=True)
+
+
+def exactly_one(*args) -> bool:
+    """
+    Returns True if exactly one of *args is "truthy", and False otherwise.
+
+    If user supplies an iterable, we raise ValueError and force them to unpack.
+    """
+    if is_container(args[0]):
+        raise ValueError(
+            "Not supported for iterable args. Use `*` to unpack your iterable in the function call."
+        )
+    return sum(map(bool, args)) == 1
diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py
index b0e2ca7..7b4424d 100644
--- a/tests/utils/test_helpers.py
+++ b/tests/utils/test_helpers.py
@@ -16,12 +16,19 @@
 # specific language governing permissions and limitations
 # under the License.
 import re
+from itertools import product
 
 import pytest
 
 from airflow import AirflowException
 from airflow.utils import helpers, timezone
-from airflow.utils.helpers import build_airflow_url_with_query, merge_dicts, validate_group_key, validate_key
+from airflow.utils.helpers import (
+    build_airflow_url_with_query,
+    exactly_one,
+    merge_dicts,
+    validate_group_key,
+    validate_key,
+)
 from tests.test_utils.config import conf_vars
 from tests.test_utils.db import clear_db_dags, clear_db_runs
 
@@ -230,3 +237,28 @@ class TestHelpers:
                 validate_group_key(key_id)
         else:
             validate_group_key(key_id)
+
+    def test_exactly_one(self):
+        """
+        Checks that when we set ``true_count`` elements to "truthy", and others to "falsy",
+        we get the expected return.
+
+        We check for both True / False, and truthy / falsy values 'a' and '', and verify that
+        they can safely be used in any combination.
+        """
+
+        def assert_exactly_one(true=0, truthy=0, false=0, falsy=0):
+            sample = []
+            for truth_value, num in [(True, true), (False, false), ('a', truthy), ('', falsy)]:
+                if num:
+                    sample.extend([truth_value] * num)
+            if sample:
+                expected = True if true + truthy == 1 else False
+                assert exactly_one(*sample) is expected
+
+        for row in product(range(4), range(4), range(4), range(4)):
+            assert_exactly_one(*row)
+
+    def test_exactly_one_should_fail(self):
+        with pytest.raises(ValueError):
+            exactly_one([True, False])