You are viewing a plain text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@airflow.apache.org by po...@apache.org on 2023/02/28 05:46:14 UTC

[airflow] branch main updated: Fix SFTPSensor when using newer_than and there are multiple matched files (#29794)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9357c81828 Fix SFTPSensor when using newer_than and there are multiple matched files (#29794)
9357c81828 is described below

commit 9357c81828626754c990c3e8192880511a510544
Author: Hussein Awala <ho...@gmail.com>
AuthorDate: Tue Feb 28 06:45:59 2023 +0100

    Fix SFTPSensor when using newer_than and there are multiple matched files (#29794)
    
    * Add a method to SFTP hook to get all matched files
    
    * fix SFTPSensor with newer_than when there are multiple matched files
---
 airflow/providers/sftp/hooks/sftp.py      | 15 ++++++
 airflow/providers/sftp/sensors/sftp.py    | 40 +++++++-------
 tests/providers/sftp/hooks/test_sftp.py   | 16 ++++++
 tests/providers/sftp/sensors/test_sftp.py | 87 ++++++++++++++++++++++++++++++-
 4 files changed, 138 insertions(+), 20 deletions(-)

diff --git a/airflow/providers/sftp/hooks/sftp.py b/airflow/providers/sftp/hooks/sftp.py
index 450b911003..c1f0052d19 100644
--- a/airflow/providers/sftp/hooks/sftp.py
+++ b/airflow/providers/sftp/hooks/sftp.py
@@ -391,3 +391,18 @@ class SFTPHook(SSHHook):
                 return file
 
         return ""
+
+    def get_files_by_pattern(self, path: str, fnmatch_pattern: str) -> list[str]:
+        """
+        Return the names listed in ``path`` that match the given fnmatch-style pattern.
+
+        Unlike ``get_file_by_pattern``, which returns a single match, this returns every
+        match. The pattern is tested against the names returned by ``list_directory``,
+        so callers must join the results with ``path`` to get full paths.
+
+        :param path: path to be checked
+        :param fnmatch_pattern: The pattern that will be matched with `fnmatch`
+        :return: list of matching names, in the order ``list_directory`` returns them,
+            or an empty list if none matched
+        """
+        return [file for file in self.list_directory(path) if fnmatch(file, fnmatch_pattern)]
diff --git a/airflow/providers/sftp/sensors/sftp.py b/airflow/providers/sftp/sensors/sftp.py
index 51af1c538f..8a84ee2f14 100644
--- a/airflow/providers/sftp/sensors/sftp.py
+++ b/airflow/providers/sftp/sensors/sftp.py
@@ -68,25 +68,33 @@ class SFTPSensor(BaseSensorOperator):
         self.log.info("Poking for %s, with pattern %s", self.path, self.file_pattern)
 
-        if self.file_pattern:
-            file_from_pattern = self.hook.get_file_by_pattern(self.path, self.file_pattern)
-            if file_from_pattern:
-                actual_file_to_check = os.path.join(self.path, file_from_pattern)
-            else:
-                return False
-        else:
-            actual_file_to_check = self.path
-
         try:
-            mod_time = self.hook.get_mod_time(actual_file_to_check)
-            self.log.info("Found File %s last modified: %s", str(actual_file_to_check), str(mod_time))
-        except OSError as e:
-            if e.errno != SFTP_NO_SUCH_FILE:
-                raise e
-            return False
-        self.hook.close_conn()
-        if self.newer_than:
-            _mod_time = convert_to_utc(datetime.strptime(mod_time, "%Y%m%d%H%M%S"))
-            _newer_than = convert_to_utc(self.newer_than)
-            return _newer_than <= _mod_time
-        else:
-            return True
+            if self.file_pattern:
+                files_from_pattern = self.hook.get_files_by_pattern(self.path, self.file_pattern)
+                if not files_from_pattern:
+                    return False
+                actual_files_to_check = [
+                    os.path.join(self.path, file_from_pattern) for file_from_pattern in files_from_pattern
+                ]
+            else:
+                actual_files_to_check = [self.path]
+
+            # Succeed on the first file that exists and, if ``newer_than`` is set, is recent enough.
+            for actual_file_to_check in actual_files_to_check:
+                try:
+                    mod_time = self.hook.get_mod_time(actual_file_to_check)
+                    self.log.info("Found File %s last modified: %s", actual_file_to_check, mod_time)
+                except OSError as e:
+                    if e.errno != SFTP_NO_SUCH_FILE:
+                        raise
+                    # This candidate does not exist (any more); keep checking the others.
+                    continue
+                if not self.newer_than:
+                    return True
+                _mod_time = convert_to_utc(datetime.strptime(mod_time, "%Y%m%d%H%M%S"))
+                _newer_than = convert_to_utc(self.newer_than)
+                if _newer_than <= _mod_time:
+                    return True
+            return False
+        finally:
+            # Release the SFTP connection on every exit path, early returns and errors included.
+            self.hook.close_conn()
diff --git a/tests/providers/sftp/hooks/test_sftp.py b/tests/providers/sftp/hooks/test_sftp.py
index 4d7f7bb562..76e41c0fa7 100644
--- a/tests/providers/sftp/hooks/test_sftp.py
+++ b/tests/providers/sftp/hooks/test_sftp.py
@@ -423,6 +423,22 @@ class TestSFTPHook:
         output = self.hook.get_file_by_pattern(TMP_PATH, "*_file_*.txt")
         assert output == ANOTHER_FILE_FOR_TESTS
 
+    def test_get_none_matched_files(self):
+        output = self.hook.get_files_by_pattern(TMP_PATH, "*.text")
+        assert output == []  # no match gives an empty list (get_file_by_pattern returns "" instead)
+
+    def test_get_matched_files_several_pattern(self):
+        output = self.hook.get_files_by_pattern(TMP_PATH, "*.log")
+        assert output == [LOG_FILE_FOR_TESTS]  # a single match is still returned as a list
+
+    def test_get_all_matched_files(self):
+        output = self.hook.get_files_by_pattern(TMP_PATH, "test_*.txt")
+        assert output == [TMP_FILE_FOR_TESTS, ANOTHER_FILE_FOR_TESTS]  # every match, in listing order
+
+    def test_get_matched_files_with_different_pattern(self):
+        output = self.hook.get_files_by_pattern(TMP_PATH, "*_file_*.txt")
+        assert output == [ANOTHER_FILE_FOR_TESTS]  # cf. the get_file_by_pattern test above
+
     def teardown_method(self):
         shutil.rmtree(os.path.join(TMP_PATH, TMP_DIR_FOR_TESTS))
         for file_name in [TMP_FILE_FOR_TESTS, ANOTHER_FILE_FOR_TESTS, LOG_FILE_FOR_TESTS]:
diff --git a/tests/providers/sftp/sensors/test_sftp.py b/tests/providers/sftp/sensors/test_sftp.py
index a2cb5a3bde..6895c158c8 100644
--- a/tests/providers/sftp/sensors/test_sftp.py
+++ b/tests/providers/sftp/sensors/test_sftp.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 from datetime import datetime
+from unittest import mock
 from unittest.mock import patch
 
 import pytest
@@ -101,7 +102,7 @@ class TestSFTPSensor:
     @patch("airflow.providers.sftp.sensors.sftp.SFTPHook")
     def test_file_present_with_pattern(self, sftp_hook_mock):
         sftp_hook_mock.return_value.get_mod_time.return_value = "19700101000000"
-        sftp_hook_mock.return_value.get_file_by_pattern.return_value = "text_file.txt"
+        sftp_hook_mock.return_value.get_files_by_pattern.return_value = ["text_file.txt"]
         sftp_sensor = SFTPSensor(task_id="unit_test", path="/path/to/file/", file_pattern="*.txt")
         context = {"ds": "1970-01-01"}
         output = sftp_sensor.poke(context)
@@ -111,8 +112,103 @@ class TestSFTPSensor:
     @patch("airflow.providers.sftp.sensors.sftp.SFTPHook")
     def test_file_not_present_with_pattern(self, sftp_hook_mock):
         sftp_hook_mock.return_value.get_mod_time.return_value = "19700101000000"
-        sftp_hook_mock.return_value.get_file_by_pattern.return_value = ""
+        sftp_hook_mock.return_value.get_files_by_pattern.return_value = []
         sftp_sensor = SFTPSensor(task_id="unit_test", path="/path/to/file/", file_pattern="*.txt")
         context = {"ds": "1970-01-01"}
         output = sftp_sensor.poke(context)
         assert not output
+
+    @patch("airflow.providers.sftp.sensors.sftp.SFTPHook")
+    def test_multiple_files_present_with_pattern(self, sftp_hook_mock):
+        sftp_hook_mock.return_value.get_mod_time.return_value = "19700101000000"
+        sftp_hook_mock.return_value.get_files_by_pattern.return_value = [
+            "text_file.txt",
+            "another_text_file.txt",
+        ]
+        sftp_sensor = SFTPSensor(task_id="unit_test", path="/path/to/file/", file_pattern="*.txt")
+        context = {"ds": "1970-01-01"}
+        output = sftp_sensor.poke(context)
+        # Without newer_than, the first existing match is enough and the others are not checked.
+        sftp_hook_mock.return_value.get_mod_time.assert_called_once_with("/path/to/file/text_file.txt")
+        assert output
+
+    @patch("airflow.providers.sftp.sensors.sftp.SFTPHook")
+    def test_multiple_files_present_with_pattern_first_missing(self, sftp_hook_mock):
+        sftp_hook_mock.return_value.get_files_by_pattern.return_value = [
+            "text_file.txt",
+            "another_text_file.txt",
+        ]
+        # The first match is gone by the time it is checked; the sensor must move on to the next one.
+        sftp_hook_mock.return_value.get_mod_time.side_effect = [
+            OSError(SFTP_NO_SUCH_FILE, "File missing"),
+            "19700101000000",
+        ]
+        sftp_sensor = SFTPSensor(task_id="unit_test", path="/path/to/file/", file_pattern="*.txt")
+        context = {"ds": "1970-01-01"}
+        output = sftp_sensor.poke(context)
+        sftp_hook_mock.return_value.get_mod_time.assert_has_calls(
+            [
+                mock.call("/path/to/file/text_file.txt"),
+                mock.call("/path/to/file/another_text_file.txt"),
+            ]
+        )
+        assert output
+
+    @patch("airflow.providers.sftp.sensors.sftp.SFTPHook")
+    def test_multiple_files_present_with_pattern_and_newer_than(self, sftp_hook_mock):
+        sftp_hook_mock.return_value.get_files_by_pattern.return_value = [
+            "text_file1.txt",
+            "text_file2.txt",
+            "text_file3.txt",
+        ]
+        sftp_hook_mock.return_value.get_mod_time.side_effect = [
+            "19500101000000",
+            "19700101000000",
+            "19800101000000",
+        ]
+        tz = timezone("America/Toronto")
+        sftp_sensor = SFTPSensor(
+            task_id="unit_test",
+            path="/path/to/file/",
+            file_pattern="*.txt",
+            newer_than=tz.convert(datetime(1960, 1, 2)),
+        )
+        context = {"ds": "1970-01-01"}
+        output = sftp_sensor.poke(context)
+        # text_file2.txt is the first match newer than the threshold, so text_file3.txt is never checked.
+        sftp_hook_mock.return_value.get_mod_time.assert_has_calls(
+            [mock.call("/path/to/file/text_file1.txt"), mock.call("/path/to/file/text_file2.txt")]
+        )
+        assert sftp_hook_mock.return_value.get_mod_time.call_count == 2
+        assert output
+
+    @patch("airflow.providers.sftp.sensors.sftp.SFTPHook")
+    def test_multiple_old_files_present_with_pattern_and_newer_than(self, sftp_hook_mock):
+        sftp_hook_mock.return_value.get_files_by_pattern.return_value = [
+            "text_file1.txt",
+            "text_file2.txt",
+            "text_file3.txt",
+        ]
+        sftp_hook_mock.return_value.get_mod_time.side_effect = [
+            "19500101000000",
+            "19510101000000",
+            "19520101000000",
+        ]
+        tz = timezone("America/Toronto")
+        sftp_sensor = SFTPSensor(
+            task_id="unit_test",
+            path="/path/to/file/",
+            file_pattern="*.txt",
+            newer_than=tz.convert(datetime(1960, 1, 2)),
+        )
+        context = {"ds": "1970-01-01"}
+        output = sftp_sensor.poke(context)
+        # Every match is older than the threshold, so all of them are checked before giving up.
+        sftp_hook_mock.return_value.get_mod_time.assert_has_calls(
+            [
+                mock.call("/path/to/file/text_file1.txt"),
+                mock.call("/path/to/file/text_file2.txt"),
+                mock.call("/path/to/file/text_file3.txt"),
+            ]
+        )
+        assert not output