Posted to commits@airflow.apache.org by el...@apache.org on 2022/10/26 11:01:20 UTC

[airflow] branch main updated: Adding `preserve_file_name` param to `S3Hook.download_file` method (#26886)

This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 777b57f0c6 Adding `preserve_file_name` param to `S3Hook.download_file` method (#26886)
777b57f0c6 is described below

commit 777b57f0c6a8ca16df2b96fd17c26eab56b3f268
Author: Alex Kruchkov <36...@users.noreply.github.com>
AuthorDate: Wed Oct 26 14:01:10 2022 +0300

    Adding `preserve_file_name` param to `S3Hook.download_file` method (#26886)
    
    * Adding `preserve_file_name` param to `S3Hook.download_file` method
---
 airflow/providers/amazon/aws/hooks/s3.py    | 48 ++++++++++++++--
 tests/providers/amazon/aws/hooks/test_s3.py | 89 ++++++++++++++++++++++++++---
 2 files changed, 123 insertions(+), 14 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 5a440f39cc..a0d1dd03a6 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -29,9 +29,10 @@ from functools import wraps
 from inspect import signature
 from io import BytesIO
 from pathlib import Path
-from tempfile import NamedTemporaryFile
+from tempfile import NamedTemporaryFile, gettempdir
 from typing import Any, Callable, List, TypeVar, cast
 from urllib.parse import urlparse
+from uuid import uuid4
 
 from boto3.s3.transfer import S3Transfer, TransferConfig
 from botocore.exceptions import ClientError
@@ -879,7 +880,14 @@ class S3Hook(AwsBaseHook):
 
     @provide_bucket_name
     @unify_bucket_name_and_key
-    def download_file(self, key: str, bucket_name: str | None = None, local_path: str | None = None) -> str:
+    def download_file(
+        self,
+        key: str,
+        bucket_name: str | None = None,
+        local_path: str | None = None,
+        preserve_file_name: bool = False,
+        use_autogenerated_subdir: bool = True,
+    ) -> str:
         """
         Downloads a file from the S3 location to the local file system.
 
@@ -887,9 +895,23 @@ class S3Hook(AwsBaseHook):
         :param bucket_name: The specific bucket to use.
         :param local_path: The local path to the downloaded file. If no path is provided it will use the
             system's temporary directory.
+        :param preserve_file_name: If set to True, the downloaded file keeps the same name it has in S3;
+            when set to False, a random filename will be generated.
+            Default: False.
+        :param use_autogenerated_subdir: Pairs with 'preserve_file_name = True' to download the file into a
+            randomly generated folder inside the 'local_path', useful to avoid collisions between tasks
+            that might download a file with the same name. Set it to 'False' if you want a predictable
+            path instead.
+            Default: True.
         :return: the file name.
         :rtype: str
         """
+        self.log.info(
+            "This function shadows the 'download_file' method of the S3 API, but it is not the same. If "
+            "you want to use the original method from the S3 API, please call "
+            "'S3Hook.get_conn().download_file()'"
+        )
+
         self.log.info("Downloading source S3 file from Bucket %s with path %s", bucket_name, key)
 
         try:
@@ -902,14 +924,30 @@ class S3Hook(AwsBaseHook):
             else:
                 raise e
 
-        with NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False) as local_tmp_file:
+        if preserve_file_name:
+            local_dir = local_path if local_path else gettempdir()
+            subdir = f"airflow_tmp_dir_{uuid4().hex[0:8]}" if use_autogenerated_subdir else ""
+            filename_in_s3 = s3_obj.key.rsplit("/", 1)[-1]
+            file_path = Path(local_dir, subdir, filename_in_s3)
+
+            if file_path.is_file():
+                self.log.error("File '%s' already exists. Failing the task rather than overwriting it", file_path)
+                raise FileExistsError
+
+            file_path.parent.mkdir(exist_ok=True, parents=True)
+
+            file = open(file_path, "wb")
+        else:
+            file = NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False)  # type: ignore
+
+        with file:
             s3_obj.download_fileobj(
-                local_tmp_file,
+                file,
                 ExtraArgs=self.extra_args,
                 Config=self.transfer_config,
             )
 
-        return local_tmp_file.name
+        return file.name
 
     def generate_presigned_url(
         self,
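
For context, here is a minimal usage sketch of the new parameters, matching the
signature added above (the bucket, key, and local path are hypothetical, not
taken from the commit):

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    hook = S3Hook(aws_conn_id="aws_default")

    # Default behavior (preserve_file_name=False): a randomly named
    # "airflow_tmp_*" file in the system's temporary directory.
    path = hook.download_file(key="data/report.csv", bucket_name="my-bucket")

    # preserve_file_name=True keeps the S3 file name; with the autogenerated
    # subdir enabled (the default), the result looks like
    # /opt/downloads/airflow_tmp_dir_<8 hex chars>/report.csv
    path = hook.download_file(
        key="data/report.csv",
        bucket_name="my-bucket",
        local_path="/opt/downloads",
        preserve_file_name=True,
    )

    # use_autogenerated_subdir=False gives the fully predictable path
    # /opt/downloads/report.csv, but raises FileExistsError if that file
    # already exists rather than overwriting it.
    path = hook.download_file(
        key="data/report.csv",
        bucket_name="my-bucket",
        local_path="/opt/downloads",
        preserve_file_name=True,
        use_autogenerated_subdir=False,
    )
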
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index f465366ab8..635431fede 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 import gzip as gz
 import os
 import tempfile
+from pathlib import Path
 from unittest import mock
 from unittest.mock import Mock
 
@@ -532,24 +533,94 @@ class TestAwsS3Hook:
 
     @mock.patch("airflow.providers.amazon.aws.hooks.s3.NamedTemporaryFile")
     def test_download_file(self, mock_temp_file):
-        mock_temp_file.return_value.__enter__ = Mock(return_value=mock_temp_file)
+        with tempfile.NamedTemporaryFile(dir="/tmp", prefix="airflow_tmp_test_s3_hook") as temp_file:
+            mock_temp_file.return_value = temp_file
+            s3_hook = S3Hook(aws_conn_id="s3_test")
+            s3_hook.check_for_key = Mock(return_value=True)
+            s3_obj = Mock()
+            s3_obj.download_fileobj = Mock(return_value=None)
+            s3_hook.get_key = Mock(return_value=s3_obj)
+            key = "test_key"
+            bucket = "test_bucket"
+
+            output_file = s3_hook.download_file(key=key, bucket_name=bucket)
+
+            s3_hook.get_key.assert_called_once_with(key, bucket)
+            s3_obj.download_fileobj.assert_called_once_with(
+                temp_file,
+                Config=s3_hook.transfer_config,
+                ExtraArgs=s3_hook.extra_args,
+            )
+
+            assert temp_file.name == output_file
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
+    def test_download_file_with_preserve_name(self, mock_open):
+        file_name = "test.log"
+        bucket = "test_bucket"
+        key = f"test_key/{file_name}"
+        local_folder = "/tmp"
+
         s3_hook = S3Hook(aws_conn_id="s3_test")
         s3_hook.check_for_key = Mock(return_value=True)
         s3_obj = Mock()
+        s3_obj.key = f"s3://{bucket}/{key}"
         s3_obj.download_fileobj = Mock(return_value=None)
         s3_hook.get_key = Mock(return_value=s3_obj)
-        key = "test_key"
-        bucket = "test_bucket"
+        s3_hook.download_file(
+            key=key,
+            bucket_name=bucket,
+            local_path=local_folder,
+            preserve_file_name=True,
+            use_autogenerated_subdir=False,
+        )
 
-        s3_hook.download_file(key=key, bucket_name=bucket)
+        mock_open.assert_called_once_with(Path(local_folder, file_name), "wb")
 
-        s3_hook.get_key.assert_called_once_with(key, bucket)
-        s3_obj.download_fileobj.assert_called_once_with(
-            mock_temp_file,
-            Config=s3_hook.transfer_config,
-            ExtraArgs=s3_hook.extra_args,
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
+    def test_download_file_with_preserve_name_with_autogenerated_subdir(self, mock_open):
+        file_name = "test.log"
+        bucket = "test_bucket"
+        key = f"test_key/{file_name}"
+        local_folder = "/tmp"
+
+        s3_hook = S3Hook(aws_conn_id="s3_test")
+        s3_hook.check_for_key = Mock(return_value=True)
+        s3_obj = Mock()
+        s3_obj.key = f"s3://{bucket}/{key}"
+        s3_obj.download_fileobj = Mock(return_value=None)
+        s3_hook.get_key = Mock(return_value=s3_obj)
+        result_file = s3_hook.download_file(
+            key=key,
+            bucket_name=bucket,
+            local_path=local_folder,
+            preserve_file_name=True,
+            use_autogenerated_subdir=True,
         )
 
+        assert result_file.rsplit("/", 1)[-2].startswith("airflow_tmp_dir_")
+
+    def test_download_file_with_preserve_name_file_already_exists(self):
+        with tempfile.NamedTemporaryFile(dir="/tmp", prefix="airflow_tmp_test_s3_hook") as file:
+            file_name = file.name.rsplit("/", 1)[-1]
+            bucket = "test_bucket"
+            key = f"test_key/{file_name}"
+            local_folder = "/tmp"
+            s3_hook = S3Hook(aws_conn_id="s3_test")
+            s3_hook.check_for_key = Mock(return_value=True)
+            s3_obj = Mock()
+            s3_obj.key = f"s3://{bucket}/{key}"
+            s3_obj.download_fileobj = Mock(return_value=None)
+            s3_hook.get_key = Mock(return_value=s3_obj)
+            with pytest.raises(FileExistsError):
+                s3_hook.download_file(
+                    key=key,
+                    bucket_name=bucket,
+                    local_path=local_folder,
+                    preserve_file_name=True,
+                    use_autogenerated_subdir=False,
+                )
+
     def test_generate_presigned_url(self, s3_bucket):
         hook = S3Hook()
         presigned_url = hook.generate_presigned_url(