Posted to commits@airflow.apache.org by el...@apache.org on 2022/10/26 11:01:20 UTC
[airflow] branch main updated: Adding `preserve_file_name` param to `S3Hook.download_file` method (#26886)
This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 777b57f0c6 Adding `preserve_file_name` param to `S3Hook.download_file` method (#26886)
777b57f0c6 is described below
commit 777b57f0c6a8ca16df2b96fd17c26eab56b3f268
Author: Alex Kruchkov <36...@users.noreply.github.com>
AuthorDate: Wed Oct 26 14:01:10 2022 +0300
Adding `preserve_file_name` param to `S3Hook.download_file` method (#26886)
* Adding `preserve_file_name` param to `S3Hook.download_file` method
---
airflow/providers/amazon/aws/hooks/s3.py | 48 ++++++++++++++--
tests/providers/amazon/aws/hooks/test_s3.py | 89 ++++++++++++++++++++++++++---
2 files changed, 123 insertions(+), 14 deletions(-)
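For callers, the change surfaces as two new keyword arguments on `download_file`. A minimal usage sketch (the connection id, bucket, key, and paths below are illustrative, not from the commit; it assumes a working AWS connection and an existing S3 object):

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    hook = S3Hook(aws_conn_id="aws_default")

    # Default behaviour (unchanged): random airflow_tmp_* file name in the
    # system temp dir (or in local_path, if given).
    tmp_path = hook.download_file(key="logs/app.log", bucket_name="my-bucket")

    # Keep the S3 file name, inside a randomly named subdirectory of
    # local_path so concurrent tasks downloading the same key don't collide.
    kept_path = hook.download_file(
        key="logs/app.log",
        bucket_name="my-bucket",
        local_path="/tmp/downloads",
        preserve_file_name=True,
    )

    # Fully predictable path (/tmp/downloads/app.log); raises
    # FileExistsError if the file is already there.
    fixed_path = hook.download_file(
        key="logs/app.log",
        bucket_name="my-bucket",
        local_path="/tmp/downloads",
        preserve_file_name=True,
        use_autogenerated_subdir=False,
    )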
diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index 5a440f39cc..a0d1dd03a6 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -29,9 +29,10 @@ from functools import wraps
from inspect import signature
from io import BytesIO
from pathlib import Path
-from tempfile import NamedTemporaryFile
+from tempfile import NamedTemporaryFile, gettempdir
from typing import Any, Callable, List, TypeVar, cast
from urllib.parse import urlparse
+from uuid import uuid4
from boto3.s3.transfer import S3Transfer, TransferConfig
from botocore.exceptions import ClientError
@@ -879,7 +880,14 @@ class S3Hook(AwsBaseHook):
@provide_bucket_name
@unify_bucket_name_and_key
- def download_file(self, key: str, bucket_name: str | None = None, local_path: str | None = None) -> str:
+ def download_file(
+ self,
+ key: str,
+ bucket_name: str | None = None,
+ local_path: str | None = None,
+ preserve_file_name: bool = False,
+ use_autogenerated_subdir: bool = True,
+ ) -> str:
"""
Downloads a file from the S3 location to the local file system.
@@ -887,9 +895,23 @@ class S3Hook(AwsBaseHook):
:param bucket_name: The specific bucket to use.
:param local_path: The local path to the downloaded file. If no path is provided it will use the
system's temporary directory.
+ :param preserve_file_name: If you want the downloaded file name to match the object's name in S3,
+ set this parameter to True. When set to False, a random filename will be generated.
+ Default: False.
+ :param use_autogenerated_subdir: Pairs with 'preserve_file_name = True' to download the file into a
+ randomly generated folder inside the 'local_path'; useful to avoid collisions between various tasks
+ that might download the same file name. Set it to 'False' if you need a predictable path instead.
+ Default: True.
:return: the file name.
:rtype: str
"""
+ self.log.info(
+ "This function shadows the 'download_file' method of S3 API, but it is not the same. If you "
+ "want to use the original method from S3 API, please call "
+ "'S3Hook.get_conn().download_file()'"
+ )
+
self.log.info("Downloading source S3 file from Bucket %s with path %s", bucket_name, key)
try:
@@ -902,14 +924,30 @@ class S3Hook(AwsBaseHook):
else:
raise e
- with NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False) as local_tmp_file:
+ if preserve_file_name:
+ local_dir = local_path if local_path else gettempdir()
+ subdir = f"airflow_tmp_dir_{uuid4().hex[0:8]}" if use_autogenerated_subdir else ""
+ filename_in_s3 = s3_obj.key.rsplit("/", 1)[-1]
+ file_path = Path(local_dir, subdir, filename_in_s3)
+
+ if file_path.is_file():
+ self.log.error("file '%s' already exists. Failing the task and not overwriting it", file_path)
+ raise FileExistsError
+
+ file_path.parent.mkdir(exist_ok=True, parents=True)
+
+ file = open(file_path, "wb")
+ else:
+ file = NamedTemporaryFile(dir=local_path, prefix="airflow_tmp_", delete=False) # type: ignore
+
+ with file:
s3_obj.download_fileobj(
- local_tmp_file,
+ file,
ExtraArgs=self.extra_args,
Config=self.transfer_config,
)
- return local_tmp_file.name
+ return file.name
def generate_presigned_url(
self,
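The path handling in the new branch is plain pathlib plus uuid; the following standalone sketch mirrors the hook's logic (variable names follow the diff, the key value is made up):

    from pathlib import Path
    from tempfile import gettempdir
    from uuid import uuid4

    local_path = None                       # as passed to download_file
    s3_key = "logs/2022/10/26/app.log"      # s3_obj.key

    # Fall back to the system temp dir when no local_path was given.
    local_dir = local_path if local_path else gettempdir()
    # Random subdir when use_autogenerated_subdir=True, "" otherwise.
    subdir = f"airflow_tmp_dir_{uuid4().hex[0:8]}"
    # Everything after the last "/" in the key is the file name.
    filename_in_s3 = s3_key.rsplit("/", 1)[-1]

    file_path = Path(local_dir, subdir, filename_in_s3)
    print(file_path)  # e.g. /tmp/airflow_tmp_dir_1a2b3c4d/app.log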
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index f465366ab8..635431fede 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -20,6 +20,7 @@ from __future__ import annotations
import gzip as gz
import os
import tempfile
+from pathlib import Path
from unittest import mock
from unittest.mock import Mock
@@ -532,24 +533,94 @@ class TestAwsS3Hook:
@mock.patch("airflow.providers.amazon.aws.hooks.s3.NamedTemporaryFile")
def test_download_file(self, mock_temp_file):
- mock_temp_file.return_value.__enter__ = Mock(return_value=mock_temp_file)
+ with tempfile.NamedTemporaryFile(dir="/tmp", prefix="airflow_tmp_test_s3_hook") as temp_file:
+ mock_temp_file.return_value = temp_file
+ s3_hook = S3Hook(aws_conn_id="s3_test")
+ s3_hook.check_for_key = Mock(return_value=True)
+ s3_obj = Mock()
+ s3_obj.download_fileobj = Mock(return_value=None)
+ s3_hook.get_key = Mock(return_value=s3_obj)
+ key = "test_key"
+ bucket = "test_bucket"
+
+ output_file = s3_hook.download_file(key=key, bucket_name=bucket)
+
+ s3_hook.get_key.assert_called_once_with(key, bucket)
+ s3_obj.download_fileobj.assert_called_once_with(
+ temp_file,
+ Config=s3_hook.transfer_config,
+ ExtraArgs=s3_hook.extra_args,
+ )
+
+ assert temp_file.name == output_file
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
+ def test_download_file_with_preserve_name(self, mock_open):
+ file_name = "test.log"
+ bucket = "test_bucket"
+ key = f"test_key/{file_name}"
+ local_folder = "/tmp"
+
s3_hook = S3Hook(aws_conn_id="s3_test")
s3_hook.check_for_key = Mock(return_value=True)
s3_obj = Mock()
+ s3_obj.key = f"s3://{bucket}/{key}"
s3_obj.download_fileobj = Mock(return_value=None)
s3_hook.get_key = Mock(return_value=s3_obj)
- key = "test_key"
- bucket = "test_bucket"
+ s3_hook.download_file(
+ key=key,
+ bucket_name=bucket,
+ local_path=local_folder,
+ preserve_file_name=True,
+ use_autogenerated_subdir=False,
+ )
- s3_hook.download_file(key=key, bucket_name=bucket)
+ mock_open.assert_called_once_with(Path(local_folder, file_name), "wb")
- s3_hook.get_key.assert_called_once_with(key, bucket)
- s3_obj.download_fileobj.assert_called_once_with(
- mock_temp_file,
- Config=s3_hook.transfer_config,
- ExtraArgs=s3_hook.extra_args,
+ @mock.patch("airflow.providers.amazon.aws.hooks.s3.open")
+ def test_download_file_with_preserve_name_with_autogenerated_subdir(self, mock_open):
+ file_name = "test.log"
+ bucket = "test_bucket"
+ key = f"test_key/{file_name}"
+ local_folder = "/tmp"
+
+ s3_hook = S3Hook(aws_conn_id="s3_test")
+ s3_hook.check_for_key = Mock(return_value=True)
+ s3_obj = Mock()
+ s3_obj.key = f"s3://{bucket}/{key}"
+ s3_obj.download_fileobj = Mock(return_value=None)
+ s3_hook.get_key = Mock(return_value=s3_obj)
+ result_file = s3_hook.download_file(
+ key=key,
+ bucket_name=bucket,
+ local_path=local_folder,
+ preserve_file_name=True,
+ use_autogenerated_subdir=True,
)
+ assert result_file.rsplit("/", 2)[-2].startswith("airflow_tmp_dir_")
+
+ def test_download_file_with_preserve_name_file_already_exists(self):
+ with tempfile.NamedTemporaryFile(dir="/tmp", prefix="airflow_tmp_test_s3_hook") as file:
+ file_name = file.name.rsplit("/", 1)[-1]
+ bucket = "test_bucket"
+ key = f"test_key/{file_name}"
+ local_folder = "/tmp"
+ s3_hook = S3Hook(aws_conn_id="s3_test")
+ s3_hook.check_for_key = Mock(return_value=True)
+ s3_obj = Mock()
+ s3_obj.key = f"s3://{bucket}/{key}"
+ s3_obj.download_fileobj = Mock(return_value=None)
+ s3_hook.get_key = Mock(return_value=s3_obj)
+ with pytest.raises(FileExistsError):
+ s3_hook.download_file(
+ key=key,
+ bucket_name=bucket,
+ local_path=local_folder,
+ preserve_file_name=True,
+ use_autogenerated_subdir=False,
+ )
+
def test_generate_presigned_url(self, s3_bucket):
hook = S3Hook()
presigned_url = hook.generate_presigned_url(
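A note on the testing pattern above: `mock.patch("airflow.providers.amazon.aws.hooks.s3.open")` replaces the builtin `open` only where the s3 module looks it up, so the hook never touches the real filesystem and the tests can assert on the exact path and mode. A minimal self-contained sketch of the same technique (the `writer` module and its `save` function are hypothetical, for illustration only):

    # writer.py (hypothetical module under test)
    def save(path: str, data: bytes) -> None:
        with open(path, "wb") as f:
            f.write(data)

    # test_writer.py
    from unittest import mock

    import writer

    @mock.patch("writer.open")  # patches `open` in writer's namespace only
    def test_save_opens_for_binary_write(mock_open):
        writer.save("/tmp/x.bin", b"payload")
        # No real file was created; assert only on how open was called.
        mock_open.assert_called_once_with("/tmp/x.bin", "wb")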