You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2021/03/10 22:44:21 UTC

[airflow] branch master updated: Add new datetime branch operator (#11964)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 1e37a11  Add new datetime branch operator (#11964)
1e37a11 is described below

commit 1e37a11e00c065e2dafa93dec9df5f024d0aabe5
Author: Tomás Farías Santana <to...@gmail.com>
AuthorDate: Wed Mar 10 23:44:08 2021 +0100

    Add new datetime branch operator (#11964)
    
    closes: #11929
    
    This PR includes a new datetime branching operator: the current date and time, as given by datetime.datetime.now is compared against target datetime attributes, like year or hour, to decide which task id branch to take.
---
 .../example_datetime_branch_operator.py            |  83 +++++++
 airflow/operators/datetime_branch.py               | 108 +++++++++
 .../howto/operator/datetime_branch.rst             |  39 ++++
 docs/apache-airflow/howto/operator/index.rst       |   1 +
 tests/operators/test_datetime_branch.py            | 251 +++++++++++++++++++++
 5 files changed, 482 insertions(+)

diff --git a/airflow/example_dags/example_datetime_branch_operator.py b/airflow/example_dags/example_datetime_branch_operator.py
new file mode 100644
index 0000000..58b1a0a
--- /dev/null
+++ b/airflow/example_dags/example_datetime_branch_operator.py
@@ -0,0 +1,83 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Example DAG demonstrating the usage of DateTimeBranchOperator with datetime as well as time objects as
+targets.
+"""
+import datetime
+
+from airflow import DAG
+from airflow.operators.datetime_branch import DateTimeBranchOperator
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.utils.dates import days_ago
+
+args = {
+    "owner": "airflow",
+}
+
+dag = DAG(
+    dag_id="example_datetime_branch_operator",
+    start_date=days_ago(2),
+    default_args=args,
+    tags=["example"],
+    schedule_interval="@daily",
+)
+
+# [START howto_operator_datetime_branch]
+dummy_task_1 = DummyOperator(task_id='date_in_range', dag=dag)
+dummy_task_2 = DummyOperator(task_id='date_outside_range', dag=dag)
+
+cond1 = DateTimeBranchOperator(
+    task_id='datetime_branch',
+    follow_task_ids_if_true=['date_in_range'],
+    follow_task_ids_if_false=['date_outside_range'],
+    target_upper=datetime.datetime(2020, 10, 10, 15, 0, 0),
+    target_lower=datetime.datetime(2020, 10, 10, 14, 0, 0),
+    dag=dag,
+)
+
+# Run dummy_task_1 if cond1 runs between 2020-10-10 14:00:00 and 15:00:00; otherwise run dummy_task_2
+cond1 >> [dummy_task_1, dummy_task_2]
+# [END howto_operator_datetime_branch]
+
+# Second example DAG: the targets are datetime.time objects instead of full datetimes
+dag = DAG(
+    dag_id="example_datetime_branch_operator_2",
+    start_date=days_ago(2),
+    default_args=args,
+    tags=["example"],
+    schedule_interval="@daily",
+)
+# [START howto_operator_datetime_branch_next_day]
+dummy_task_1 = DummyOperator(task_id='date_in_range', dag=dag)
+dummy_task_2 = DummyOperator(task_id='date_outside_range', dag=dag)
+
+cond2 = DateTimeBranchOperator(
+    task_id='datetime_branch',
+    follow_task_ids_if_true=['date_in_range'],
+    follow_task_ids_if_false=['date_outside_range'],
+    target_upper=datetime.time(0, 0, 0),
+    target_lower=datetime.time(15, 0, 0),
+    dag=dag,
+)
+
+# Since target_lower happens after target_upper, target_upper will be moved to the following day
+# Run dummy_task_1 if cond2 runs between 15:00:00 and 00:00:00 of the following day; else dummy_task_2
+cond2 >> [dummy_task_1, dummy_task_2]
+# [END howto_operator_datetime_branch_next_day]
diff --git a/airflow/operators/datetime_branch.py b/airflow/operators/datetime_branch.py
new file mode 100644
index 0000000..868bd23
--- /dev/null
+++ b/airflow/operators/datetime_branch.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+from typing import Dict, Iterable, Union
+
+from airflow.exceptions import AirflowException
+from airflow.operators.branch_operator import BaseBranchOperator
+from airflow.utils import timezone
+from airflow.utils.decorators import apply_defaults
+
+
+class DateTimeBranchOperator(BaseBranchOperator):
+    """
+    Branches into one of two lists of tasks depending on the current datetime.
+
+    The true branch is returned when the reference datetime lies between ``target_lower``
+    and ``target_upper``, both bounds included. A bound set to ``None`` is not checked.
+
+    :param follow_task_ids_if_true: task id or task ids to follow if the reference
+        datetime is not below ``target_lower`` and not above ``target_upper``.
+    :type follow_task_ids_if_true: str or list[str]
+    :param follow_task_ids_if_false: task id or task ids to follow if the reference
+        datetime falls below ``target_lower`` or above ``target_upper``.
+    :type follow_task_ids_if_false: str or list[str]
+    :param target_lower: target lower bound.
+    :type target_lower: Optional[Union[datetime.datetime, datetime.time]]
+    :param target_upper: target upper bound.
+    :type target_upper: Optional[Union[datetime.datetime, datetime.time]]
+    :param use_task_execution_date: If ``True``, the task's execution date is compared with the
+        targets, which is useful for backfilling. If ``False``, ``timezone.utcnow()`` is used.
+    :type use_task_execution_date: bool
+    """
+
+    @apply_defaults
+    def __init__(
+        self,
+        *,
+        follow_task_ids_if_true: Union[str, Iterable[str]],
+        follow_task_ids_if_false: Union[str, Iterable[str]],
+        target_lower: Union[datetime.datetime, datetime.time, None],
+        target_upper: Union[datetime.datetime, datetime.time, None],
+        use_task_execution_date: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        if target_lower is None and target_upper is None:
+            raise AirflowException(
+                "Both target_upper and target_lower are None. At least one "
+                "must be defined to be compared to the current datetime"
+            )
+
+        self.target_lower = target_lower
+        self.target_upper = target_upper
+        self.follow_task_ids_if_true = follow_task_ids_if_true
+        self.follow_task_ids_if_false = follow_task_ids_if_false
+        self.use_task_execution_date = use_task_execution_date
+
+    def choose_branch(self, context: Dict) -> Union[str, Iterable[str]]:
+        if self.use_task_execution_date is True:
+            now = timezone.make_naive(context["execution_date"], self.dag.timezone)
+        else:
+            now = timezone.make_naive(timezone.utcnow(), self.dag.timezone)
+        # datetime.time targets take their date from ``now`` before being compared
+        lower, upper = target_times_as_dates(now, self.target_lower, self.target_upper)
+        if upper is not None and upper < now:
+            return self.follow_task_ids_if_false
+
+        if lower is not None and lower > now:
+            return self.follow_task_ids_if_false
+
+        return self.follow_task_ids_if_true
+
+
+def target_times_as_dates(
+    base_date: datetime.datetime,
+    lower: Union[datetime.datetime, datetime.time, None],
+    upper: Union[datetime.datetime, datetime.time, None],
+):
+    """Return ``(lower, upper)`` with any ``datetime.time`` target combined with ``base_date``."""
+    if isinstance(lower, datetime.datetime) and isinstance(upper, datetime.datetime):
+        return lower, upper
+
+    if lower is not None and isinstance(lower, datetime.time):
+        lower = datetime.datetime.combine(base_date, lower)
+    if upper is not None and isinstance(upper, datetime.time):
+        upper = datetime.datetime.combine(base_date, upper)
+    # With a ``None`` bound there is no pair to order, so return the targets as they are
+    if any(date is None for date in (lower, upper)):
+        return lower, upper
+    # An upper bound earlier than the lower one is taken to fall on the following day
+    if upper < lower:
+        upper += datetime.timedelta(days=1)
+    return lower, upper
diff --git a/docs/apache-airflow/howto/operator/datetime_branch.rst b/docs/apache-airflow/howto/operator/datetime_branch.rst
new file mode 100644
index 0000000..b798468
--- /dev/null
+++ b/docs/apache-airflow/howto/operator/datetime_branch.rst
@@ -0,0 +1,39 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+    or more contributor license agreements.  See the NOTICE file
+    distributed with this work for additional information
+    regarding copyright ownership.  The ASF licenses this file
+    to you under the Apache License, Version 2.0 (the
+    "License"); you may not use this file except in compliance
+    with the License.  You may obtain a copy of the License at
+
+ ..   http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+    software distributed under the License is distributed on an
+    "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+    KIND, either express or implied.  See the License for the
+    specific language governing permissions and limitations
+    under the License.
+
+
+
+.. _howto/operator:DatetimeBranch:
+
+DateTimeBranchOperator
+======================
+
+Use the :class:`~airflow.operators.datetime_branch.DateTimeBranchOperator` to branch into one of two execution paths depending on whether the date and/or time of execution falls into the range given by two target arguments.
+
+.. exampleinclude:: /../../airflow/example_dags/example_datetime_branch_operator.py
+    :language: python
+    :start-after: [START howto_operator_datetime_branch]
+    :end-before: [END howto_operator_datetime_branch]
+
+The target parameters, ``target_upper`` and ``target_lower``, can receive a ``datetime.datetime``, a ``datetime.time``, or ``None``. When a ``datetime.time`` object is used, it will be combined with the current date in order to allow comparisons with it. In the event that ``target_upper`` is set to a ``datetime.time`` that occurs before the given ``target_lower``, a day will be added to ``target_upper``. This is done to allow for time periods that span over two dates.
+
+.. exampleinclude:: /../../airflow/example_dags/example_datetime_branch_operator.py
+    :language: python
+    :start-after: [START howto_operator_datetime_branch_next_day]
+    :end-before: [END howto_operator_datetime_branch_next_day]
+
+If a target parameter is set to ``None``, the operator will perform a unilateral comparison using only the non-``None`` target. Setting both ``target_upper`` and ``target_lower`` to ``None`` will raise an exception.
diff --git a/docs/apache-airflow/howto/operator/index.rst b/docs/apache-airflow/howto/operator/index.rst
index 2cbb489..71cdf66 100644
--- a/docs/apache-airflow/howto/operator/index.rst
+++ b/docs/apache-airflow/howto/operator/index.rst
@@ -32,6 +32,7 @@ determine what actually executes when your DAG runs.
     :maxdepth: 2
 
     bash
+    datetime_branch
     python
     weekday
     external_task_sensor
diff --git a/tests/operators/test_datetime_branch.py b/tests/operators/test_datetime_branch.py
new file mode 100644
index 0000000..cfc2a86
--- /dev/null
+++ b/tests/operators/test_datetime_branch.py
@@ -0,0 +1,251 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import datetime
+import unittest
+
+import freezegun
+
+from airflow.exceptions import AirflowException
+from airflow.models import DAG, DagRun, TaskInstance as TI
+from airflow.operators.datetime_branch import DateTimeBranchOperator
+from airflow.operators.dummy_operator import DummyOperator
+from airflow.utils import timezone
+from airflow.utils.session import create_session
+from airflow.utils.state import State
+
+DEFAULT_DATE = timezone.datetime(2016, 1, 1)
+INTERVAL = datetime.timedelta(hours=12)
+
+
+class TestDateTimeBranchOperator(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+
+        with create_session() as session:
+            session.query(DagRun).delete()
+            session.query(TI).delete()
+        # (target_lower, target_upper) pairs mixing datetime and time; in the last pair lower is after upper
+        cls.targets = [
+            (datetime.datetime(2020, 7, 7, 10, 0, 0), datetime.datetime(2020, 7, 7, 11, 0, 0)),
+            (datetime.time(10, 0, 0), datetime.time(11, 0, 0)),
+            (datetime.datetime(2020, 7, 7, 10, 0, 0), datetime.time(11, 0, 0)),
+            (datetime.time(10, 0, 0), datetime.datetime(2020, 7, 7, 11, 0, 0)),
+            (datetime.time(11, 0, 0), datetime.time(10, 0, 0)),
+        ]
+
+    def setUp(self):
+        self.dag = DAG(
+            'datetime_branch_operator_test',
+            default_args={'owner': 'airflow', 'start_date': DEFAULT_DATE},
+            schedule_interval=INTERVAL,
+        )
+
+        self.branch_1 = DummyOperator(task_id='branch_1', dag=self.dag)
+        self.branch_2 = DummyOperator(task_id='branch_2', dag=self.dag)
+
+        self.branch_op = DateTimeBranchOperator(
+            task_id='datetime_branch',
+            follow_task_ids_if_true='branch_1',
+            follow_task_ids_if_false='branch_2',
+            target_upper=datetime.datetime(2020, 7, 7, 11, 0, 0),
+            target_lower=datetime.datetime(2020, 7, 7, 10, 0, 0),
+            dag=self.dag,
+        )
+
+        self.branch_1.set_upstream(self.branch_op)
+        self.branch_2.set_upstream(self.branch_op)
+        self.dag.clear()
+
+        self.dr = self.dag.create_dagrun(
+            run_id='manual__', start_date=DEFAULT_DATE, execution_date=DEFAULT_DATE, state=State.RUNNING
+        )
+
+    def tearDown(self):
+        super().tearDown()
+
+        with create_session() as session:
+            session.query(DagRun).delete()
+            session.query(TI).delete()
+
+    def _assert_task_ids_match_states(self, task_ids_to_states):
+        """Helper that asserts task instances with a given id are in a given state"""
+        tis = self.dr.get_task_instances()
+        for ti in tis:
+            try:
+                expected_state = task_ids_to_states[ti.task_id]
+            except KeyError:
+                raise ValueError(f'Invalid task id {ti.task_id} found!')
+            else:
+                self.assertEqual(
+                    ti.state,
+                    expected_state,
+                    f"Task {ti.task_id} has state {ti.state} instead of expected {expected_state}",
+                )
+
+    def test_no_target_time(self):
+        """Check if DateTimeBranchOperator raises exception when both targets are None"""
+        with self.assertRaises(AirflowException):
+            DateTimeBranchOperator(
+                task_id='datetime_branch',
+                follow_task_ids_if_true='branch_1',
+                follow_task_ids_if_false='branch_2',
+                target_upper=None,
+                target_lower=None,
+                dag=self.dag,
+            )
+
+    @freezegun.freeze_time("2020-07-07 10:54:05")
+    def test_datetime_branch_operator_falls_within_range(self):
+        """Check branch_1 (true branch) is followed at 10:54:05 with both targets set"""
+        for target_lower, target_upper in self.targets:
+            with self.subTest(target_lower=target_lower, target_upper=target_upper):
+                self.branch_op.target_lower = target_lower
+                self.branch_op.target_upper = target_upper
+                self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+                # NOTE(review): confirm run() re-executes per subTest; 10:54:05 < last pair's lower
+                self._assert_task_ids_match_states(
+                    {
+                        'datetime_branch': State.SUCCESS,
+                        'branch_1': State.NONE,
+                        'branch_2': State.SKIPPED,
+                    }
+                )
+
+    def test_datetime_branch_operator_falls_outside_range(self):
+        """Check branch_2 (false branch) is followed at the two frozen 12:00 datetimes"""
+        dates = [
+            datetime.datetime(2020, 7, 7, 12, 0, 0, tzinfo=datetime.timezone.utc),
+            datetime.datetime(2020, 6, 7, 12, 0, 0, tzinfo=datetime.timezone.utc),
+        ]
+
+        for target_lower, target_upper in self.targets:
+            with self.subTest(target_lower=target_lower, target_upper=target_upper):
+                self.branch_op.target_lower = target_lower
+                self.branch_op.target_upper = target_upper
+
+                for date in dates:
+                    with freezegun.freeze_time(date):
+                        self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+                        self._assert_task_ids_match_states(
+                            {
+                                'datetime_branch': State.SUCCESS,
+                                'branch_1': State.SKIPPED,
+                                'branch_2': State.NONE,
+                            }
+                        )
+
+    @freezegun.freeze_time("2020-07-07 10:54:05")
+    def test_datetime_branch_operator_upper_comparison_within_range(self):
+        """Check branch_1 (true branch) is followed at 10:54:05 with only target_upper set"""
+        for _, target_upper in self.targets:
+            with self.subTest(target_upper=target_upper):
+                self.branch_op.target_upper = target_upper
+                self.branch_op.target_lower = None
+
+                self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+                self._assert_task_ids_match_states(
+                    {
+                        'datetime_branch': State.SUCCESS,
+                        'branch_1': State.NONE,
+                        'branch_2': State.SKIPPED,
+                    }
+                )
+
+    @freezegun.freeze_time("2020-07-07 10:54:05")
+    def test_datetime_branch_operator_lower_comparison_within_range(self):
+        """Check branch_1 (true branch) is followed at 10:54:05 with only target_lower set"""
+        for target_lower, _ in self.targets:
+            with self.subTest(target_lower=target_lower):
+                self.branch_op.target_lower = target_lower
+                self.branch_op.target_upper = None
+
+                self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+                self._assert_task_ids_match_states(
+                    {
+                        'datetime_branch': State.SUCCESS,
+                        'branch_1': State.NONE,
+                        'branch_2': State.SKIPPED,
+                    }
+                )
+
+    @freezegun.freeze_time("2020-07-07 12:00:00")
+    def test_datetime_branch_operator_upper_comparison_outside_range(self):
+        """Check branch_2 (false branch) is followed at 12:00:00 with only target_upper set"""
+        for _, target_upper in self.targets:
+            with self.subTest(target_upper=target_upper):
+                self.branch_op.target_upper = target_upper
+                self.branch_op.target_lower = None
+
+                self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+                self._assert_task_ids_match_states(
+                    {
+                        'datetime_branch': State.SUCCESS,
+                        'branch_1': State.SKIPPED,
+                        'branch_2': State.NONE,
+                    }
+                )
+
+    @freezegun.freeze_time("2020-07-07 09:00:00")
+    def test_datetime_branch_operator_lower_comparison_outside_range(self):
+        """Check branch_2 (false branch) is followed at 09:00:00 with only target_lower set"""
+        for target_lower, _ in self.targets:
+            with self.subTest(target_lower=target_lower):
+                self.branch_op.target_lower = target_lower
+                self.branch_op.target_upper = None
+
+                self.branch_op.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE)
+
+                self._assert_task_ids_match_states(
+                    {
+                        'datetime_branch': State.SUCCESS,
+                        'branch_1': State.SKIPPED,
+                        'branch_2': State.NONE,
+                    }
+                )
+
+    @freezegun.freeze_time("2020-12-01 09:00:00")
+    def test_datetime_branch_operator_use_task_execution_date(self):
+        """Check if DateTimeBranchOperator uses task execution date"""
+        in_between_date = timezone.datetime(2020, 7, 7, 10, 30, 0)
+        self.branch_op.use_task_execution_date = True
+        self.dr = self.dag.create_dagrun(
+            run_id='manual_exec_date__',
+            start_date=in_between_date,
+            execution_date=in_between_date,
+            state=State.RUNNING,
+        )
+
+        for target_lower, target_upper in self.targets:
+            with self.subTest(target_lower=target_lower, target_upper=target_upper):
+                self.branch_op.target_lower = target_lower
+                self.branch_op.target_upper = target_upper
+                self.branch_op.run(start_date=in_between_date, end_date=in_between_date)
+
+                self._assert_task_ids_match_states(
+                    {
+                        'datetime_branch': State.SUCCESS,
+                        'branch_1': State.NONE,
+                        'branch_2': State.SKIPPED,
+                    }
+                )