You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/06/19 22:41:42 UTC
[airflow] branch main updated: Make extra_args in S3Hook immutable between calls (#24527)
This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7293e31f1c Make extra_args in S3Hook immutable between calls (#24527)
7293e31f1c is described below
commit 7293e31f1cf33f015867ac89ee00910fc9ae1972
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Mon Jun 20 02:41:36 2022 +0400
Make extra_args in S3Hook immutable between calls (#24527)
---
airflow/providers/amazon/aws/hooks/s3.py | 52 +++++++++++++++++++---------
tests/providers/amazon/aws/hooks/test_s3.py | 53 ++++++++++++++++++++++++++++-
2 files changed, 88 insertions(+), 17 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index fd130a5bdd..9c46b78685 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -23,6 +23,7 @@ import gzip as gz
import io
import re
import shutil
+from copy import deepcopy
from datetime import datetime
from functools import wraps
from inspect import signature
@@ -97,6 +98,15 @@ class S3Hook(AwsBaseHook):
"""
Interact with AWS S3, using the boto3 library.
+ :param transfer_config_args: Configuration object for managed S3 transfers.
+ :param extra_args: Extra arguments that may be passed to the download/upload operations.
+
+ .. seealso::
+ https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#s3-transfers
+
+ - For allowed upload extra arguments see ``boto3.s3.transfer.S3Transfer.ALLOWED_UPLOAD_ARGS``.
+ - For allowed download extra arguments see ``boto3.s3.transfer.S3Transfer.ALLOWED_DOWNLOAD_ARGS``.
+
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
@@ -107,26 +117,32 @@ class S3Hook(AwsBaseHook):
conn_type = 's3'
hook_name = 'Amazon S3'
- def __init__(self, *args, **kwargs) -> None:
+ def __init__(
+ self,
+ aws_conn_id: Optional[str] = AwsBaseHook.default_conn_name,
+ transfer_config_args: Optional[Dict] = None,
+ extra_args: Optional[Dict] = None,
+ *args,
+ **kwargs,
+ ) -> None:
kwargs['client_type'] = 's3'
+ kwargs['aws_conn_id'] = aws_conn_id
+
+ if extra_args and not isinstance(extra_args, dict):
+ raise ValueError(f"transfer_config_args '{extra_args!r}' must be of type {dict}")
+ self.transfer_config = TransferConfig(**transfer_config_args or {})
- self.extra_args = {}
- if 'extra_args' in kwargs:
- self.extra_args = kwargs['extra_args']
- if not isinstance(self.extra_args, dict):
- raise ValueError(f"extra_args '{self.extra_args!r}' must be of type {dict}")
- del kwargs['extra_args']
-
- self.transfer_config = TransferConfig()
- if 'transfer_config_args' in kwargs:
- transport_config_args = kwargs['transfer_config_args']
- if not isinstance(transport_config_args, dict):
- raise ValueError(f"transfer_config_args '{transport_config_args!r} must be of type {dict}")
- self.transfer_config = TransferConfig(**transport_config_args)
- del kwargs['transfer_config_args']
+ if extra_args and not isinstance(extra_args, dict):
+ raise ValueError(f"extra_args '{extra_args!r}' must be of type {dict}")
+ self._extra_args = extra_args or {}
super().__init__(*args, **kwargs)
+ @property
+ def extra_args(self):
+ """Return hook's extra arguments (immutable)."""
+ return deepcopy(self._extra_args)
+
@staticmethod
def parse_s3_url(s3url: str) -> Tuple[str, str]:
"""
@@ -867,7 +883,11 @@ class S3Hook(AwsBaseHook):
raise e
with NamedTemporaryFile(dir=local_path, prefix='airflow_tmp_', delete=False) as local_tmp_file:
- s3_obj.download_fileobj(local_tmp_file)
+ s3_obj.download_fileobj(
+ local_tmp_file,
+ ExtraArgs=self.extra_args,
+ Config=self.transfer_config,
+ )
return local_tmp_file.name
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index da63f37580..35f77a6d17 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -491,7 +491,11 @@ class TestAwsS3Hook:
s3_hook.download_file(key=key, bucket_name=bucket)
s3_hook.get_key.assert_called_once_with(key, bucket)
- s3_obj.download_fileobj.assert_called_once_with(mock_temp_file)
+ s3_obj.download_fileobj.assert_called_once_with(
+ mock_temp_file,
+ Config=s3_hook.transfer_config,
+ ExtraArgs=s3_hook.extra_args,
+ )
def test_generate_presigned_url(self, s3_bucket):
hook = S3Hook()
@@ -525,6 +529,53 @@ class TestAwsS3Hook:
resource = boto3.resource('s3').Object(s3_bucket, 'my_key')
assert resource.get()['ContentLanguage'] == "value"
+ def test_that_extra_args_not_changed_between_calls(self, s3_bucket):
+ original = {
+ "Metadata": {"metakey": "metaval"},
+ "ACL": "private",
+ "ServerSideEncryption": "aws:kms",
+ "SSEKMSKeyId": "arn:aws:kms:region:acct-id:key/key-id",
+ }
+ s3_hook = S3Hook(aws_conn_id="s3_test", extra_args=original)
+ assert s3_hook.extra_args == original
+ assert s3_hook.extra_args is not original
+
+ dummy = mock.MagicMock()
+ s3_hook.check_for_key = Mock(return_value=False)
+ mock_upload_fileobj = s3_hook.conn.upload_fileobj = Mock(return_value=None)
+ mock_upload_file = s3_hook.conn.upload_file = Mock(return_value=None)
+
+ # First Call - load_file_obj.
+ s3_hook.load_file_obj(dummy, "mock_key", s3_bucket, encrypt=True, acl_policy="public-read")
+ first_call_extra_args = mock_upload_fileobj.call_args_list[0][1]["ExtraArgs"]
+ assert s3_hook.extra_args == original
+ assert first_call_extra_args is not s3_hook.extra_args
+
+ # Second Call - load_string.
+ s3_hook.load_string("dummy", "mock_key", s3_bucket, acl_policy="bucket-owner-full-control")
+ second_call_extra_args = mock_upload_fileobj.call_args_list[1][1]["ExtraArgs"]
+ assert s3_hook.extra_args == original
+ assert second_call_extra_args is not s3_hook.extra_args
+ assert second_call_extra_args != first_call_extra_args
+
+ # Third Call - load_bytes.
+ s3_hook.load_bytes(b"dummy", "mock_key", s3_bucket, encrypt=True)
+ third_call_extra_args = mock_upload_fileobj.call_args_list[2][1]["ExtraArgs"]
+ assert s3_hook.extra_args == original
+ assert third_call_extra_args is not s3_hook.extra_args
+ assert third_call_extra_args not in [first_call_extra_args, second_call_extra_args]
+
+ # Fourth Call - load_file.
+ s3_hook.load_file("/dummy.png", "mock_key", s3_bucket, encrypt=True, acl_policy="bucket-owner-read")
+ fourth_call_extra_args = mock_upload_file.call_args_list[0][1]["ExtraArgs"]
+ assert s3_hook.extra_args == original
+ assert fourth_call_extra_args is not s3_hook.extra_args
+ assert fourth_call_extra_args not in [
+ third_call_extra_args,
+ first_call_extra_args,
+ second_call_extra_args,
+ ]
+
@mock_s3
def test_get_bucket_tagging_no_tags_raises_error(self):
hook = S3Hook()