You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by ka...@apache.org on 2021/08/20 18:42:42 UTC

[airflow] branch main updated: Ensure ``DateTimeTrigger`` receives a datetime object (#17747)

This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e21b54a  Ensure ``DateTimeTrigger`` receives a datetime object (#17747)
e21b54a is described below

commit e21b54a62e43d236ca7db0b180737f36699f88eb
Author: Kaxil Naik <ka...@gmail.com>
AuthorDate: Fri Aug 20 19:42:12 2021 +0100

    Ensure ``DateTimeTrigger`` receives a datetime object (#17747)
    
    While using the following example DAG, the task failed with an error saying that `moment` has no `tzinfo` attribute. This happened because a string was passed from `DateTimeSensor` to `DateTimeTrigger`. This PR ensures that a datetime object is passed, and also fixes the logic in `DateTimeSensor` so that `self.target_time` is always a datetime object.
    
    ```python
    from datetime import timedelta
    
    from airflow import DAG
    from airflow.sensors.date_time import DateTimeSensorAsync
    from airflow.utils import dates, timezone
    
    with DAG(
        dag_id='example_date_time_async_operator',
        schedule_interval='0 0 * * *',
        start_date=dates.days_ago(2),
        dagrun_timeout=timedelta(minutes=60),
        tags=['example', 'example2', 'async'],
    ) as dag:
    
        DateTimeSensorAsync(task_id="test", target_time=timezone.datetime(2021, 8, 19, 23, 15, 0))
    
    ```
---
 airflow/sensors/date_time.py    | 8 +++++---
 airflow/triggers/temporal.py    | 4 +++-
 tests/sensors/test_date_time.py | 9 +++++++--
 tests/triggers/test_temporal.py | 8 ++++++++
 4 files changed, 23 insertions(+), 6 deletions(-)

diff --git a/airflow/sensors/date_time.py b/airflow/sensors/date_time.py
index 3eca931..b9452cb 100644
--- a/airflow/sensors/date_time.py
+++ b/airflow/sensors/date_time.py
@@ -59,9 +59,11 @@ class DateTimeSensor(BaseSensorOperator):
     def __init__(self, *, target_time: Union[str, datetime.datetime], **kwargs) -> None:
         super().__init__(**kwargs)
         if isinstance(target_time, datetime.datetime):
-            self.target_time = target_time.isoformat()
-        elif isinstance(target_time, str):
+            if timezone.is_naive(target_time):
+                target_time = timezone.make_aware(target_time)
             self.target_time = target_time
+        elif isinstance(target_time, str):
+            self.target_time = timezone.parse(target_time)
         else:
             raise TypeError(
                 f"Expected str or datetime.datetime type for target_time. Got {type(target_time)}"
@@ -69,7 +71,7 @@ class DateTimeSensor(BaseSensorOperator):
 
     def poke(self, context: Dict) -> bool:
         self.log.info("Checking if the time (%s) has come", self.target_time)
-        return timezone.utcnow() > timezone.parse(self.target_time)
+        return timezone.utcnow() > self.target_time
 
 
 class DateTimeSensorAsync(DateTimeSensor):
diff --git a/airflow/triggers/temporal.py b/airflow/triggers/temporal.py
index 26685e3..7cd67e2 100644
--- a/airflow/triggers/temporal.py
+++ b/airflow/triggers/temporal.py
@@ -33,8 +33,10 @@ class DateTimeTrigger(BaseTrigger):
 
     def __init__(self, moment: datetime.datetime):
         super().__init__()
+        if not isinstance(moment, datetime.datetime):
+            raise TypeError(f"Expected datetime.datetime type for moment. Got {type(moment)}")
         # Make sure it's in UTC
-        if moment.tzinfo is None:
+        elif moment.tzinfo is None:
             raise ValueError("You cannot pass naive datetimes")
         elif not hasattr(moment.tzinfo, "offset") or moment.tzinfo.offset != 0:
             raise ValueError(f"The passed datetime must be using Pendulum's UTC, not {moment.tzinfo!r}")
diff --git a/tests/sensors/test_date_time.py b/tests/sensors/test_date_time.py
index f201189..a283bd9 100644
--- a/tests/sensors/test_date_time.py
+++ b/tests/sensors/test_date_time.py
@@ -15,6 +15,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import datetime
 from unittest.mock import patch
 
 import pytest
@@ -38,9 +39,13 @@ class TestDateTimeSensor:
             (
                 "valid_datetime",
                 timezone.datetime(2020, 7, 6, 13, tzinfo=timezone.utc),
-                "2020-07-06T13:00:00+00:00",
+                timezone.datetime(2020, 7, 6, 13, tzinfo=timezone.utc),
+            ),
+            (
+                "valid_str",
+                "20200706T210000+8",
+                timezone.datetime(2020, 7, 6, 21, tzinfo=datetime.timezone(datetime.timedelta(hours=8))),
             ),
-            ("valid_str", "20200706T210000+8", "20200706T210000+8"),
         ]
     )
     def test_valid_input(self, task_id, target_time, expected):
diff --git a/tests/triggers/test_temporal.py b/tests/triggers/test_temporal.py
index b541c8f..02baadf 100644
--- a/tests/triggers/test_temporal.py
+++ b/tests/triggers/test_temporal.py
@@ -28,6 +28,14 @@ from airflow.triggers.temporal import DateTimeTrigger, TimeDeltaTrigger
 from airflow.utils import timezone
 
 
+def test_input_validation():
+    """
+    Tests that the DateTimeTrigger validates input to moment arg, it should only accept datetime.
+    """
+    with pytest.raises(TypeError, match="Expected datetime.datetime type for moment. Got <class 'str'>"):
+        DateTimeTrigger('2012-01-01T03:03:03+00:00')
+
+
 def test_datetime_trigger_serialization():
     """
     Tests that the DateTimeTrigger correctly serializes its arguments