You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@airflow.apache.org by po...@apache.org on 2022/06/19 22:41:42 UTC

[airflow] branch main updated: Make extra_args in S3Hook immutable between calls (#24527)

This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7293e31f1c Make extra_args in S3Hook immutable between calls (#24527)
7293e31f1c is described below

commit 7293e31f1cf33f015867ac89ee00910fc9ae1972
Author: Andrey Anshin <An...@taragol.is>
AuthorDate: Mon Jun 20 02:41:36 2022 +0400

    Make extra_args in S3Hook immutable between calls (#24527)
---
 airflow/providers/amazon/aws/hooks/s3.py    | 52 +++++++++++++++++++---------
 tests/providers/amazon/aws/hooks/test_s3.py | 53 ++++++++++++++++++++++++++++-
 2 files changed, 88 insertions(+), 17 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index fd130a5bdd..9c46b78685 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -23,6 +23,7 @@ import gzip as gz
 import io
 import re
 import shutil
+from copy import deepcopy
 from datetime import datetime
 from functools import wraps
 from inspect import signature
@@ -97,6 +98,15 @@ class S3Hook(AwsBaseHook):
     """
     Interact with AWS S3, using the boto3 library.
 
+    :param transfer_config_args: Configuration object for managed S3 transfers.
+    :param extra_args: Extra arguments that may be passed to the download/upload operations.
+
+    .. seealso::
+        https://boto3.amazonaws.com/v1/documentation/api/latest/reference/customizations/s3.html#s3-transfers
+
+        - For allowed upload extra arguments see ``boto3.s3.transfer.S3Transfer.ALLOWED_UPLOAD_ARGS``.
+        - For allowed download extra arguments see ``boto3.s3.transfer.S3Transfer.ALLOWED_DOWNLOAD_ARGS``.
+
     Additional arguments (such as ``aws_conn_id``) may be specified and
     are passed down to the underlying AwsBaseHook.
 
@@ -107,26 +117,32 @@ class S3Hook(AwsBaseHook):
     conn_type = 's3'
     hook_name = 'Amazon S3'
 
-    def __init__(self, *args, **kwargs) -> None:
+    def __init__(
+        self,
+        aws_conn_id: Optional[str] = AwsBaseHook.default_conn_name,
+        transfer_config_args: Optional[Dict] = None,
+        extra_args: Optional[Dict] = None,
+        *args,
+        **kwargs,
+    ) -> None:
         kwargs['client_type'] = 's3'
+        kwargs['aws_conn_id'] = aws_conn_id
+        # Validate each optional dict under its own name so the error points at the offending argument.
+        if transfer_config_args and not isinstance(transfer_config_args, dict):
+            raise ValueError(f"transfer_config_args '{transfer_config_args!r}' must be of type {dict}")
+        self.transfer_config = TransferConfig(**transfer_config_args or {})
 
-        self.extra_args = {}
-        if 'extra_args' in kwargs:
-            self.extra_args = kwargs['extra_args']
-            if not isinstance(self.extra_args, dict):
-                raise ValueError(f"extra_args '{self.extra_args!r}' must be of type {dict}")
-            del kwargs['extra_args']
-
-        self.transfer_config = TransferConfig()
-        if 'transfer_config_args' in kwargs:
-            transport_config_args = kwargs['transfer_config_args']
-            if not isinstance(transport_config_args, dict):
-                raise ValueError(f"transfer_config_args '{transport_config_args!r} must be of type {dict}")
-            self.transfer_config = TransferConfig(**transport_config_args)
-            del kwargs['transfer_config_args']
+        if extra_args and not isinstance(extra_args, dict):
+            raise ValueError(f"extra_args '{extra_args!r}' must be of type {dict}")
+        self._extra_args = extra_args or {}
 
         super().__init__(*args, **kwargs)
 
+    @property
+    def extra_args(self):
+        """Return a deep copy of the hook's extra arguments, so callers cannot mutate the stored dict."""
+        return deepcopy(self._extra_args)
+
     @staticmethod
     def parse_s3_url(s3url: str) -> Tuple[str, str]:
         """
@@ -867,7 +883,11 @@ class S3Hook(AwsBaseHook):
                 raise e
 
         with NamedTemporaryFile(dir=local_path, prefix='airflow_tmp_', delete=False) as local_tmp_file:
-            s3_obj.download_fileobj(local_tmp_file)
+            s3_obj.download_fileobj(
+                local_tmp_file,
+                ExtraArgs=self.extra_args,
+                Config=self.transfer_config,
+            )
 
         return local_tmp_file.name
 
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index da63f37580..35f77a6d17 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -491,7 +491,11 @@ class TestAwsS3Hook:
         s3_hook.download_file(key=key, bucket_name=bucket)
 
         s3_hook.get_key.assert_called_once_with(key, bucket)
-        s3_obj.download_fileobj.assert_called_once_with(mock_temp_file)
+        s3_obj.download_fileobj.assert_called_once_with(
+            mock_temp_file,
+            Config=s3_hook.transfer_config,
+            ExtraArgs=s3_hook.extra_args,
+        )
 
     def test_generate_presigned_url(self, s3_bucket):
         hook = S3Hook()
@@ -525,6 +529,53 @@ class TestAwsS3Hook:
             resource = boto3.resource('s3').Object(s3_bucket, 'my_key')
             assert resource.get()['ContentLanguage'] == "value"
 
+    def test_that_extra_args_not_changed_between_calls(self, s3_bucket):
+        original = {
+            "Metadata": {"metakey": "metaval"},
+            "ACL": "private",
+            "ServerSideEncryption": "aws:kms",
+            "SSEKMSKeyId": "arn:aws:kms:region:acct-id:key/key-id",
+        }
+        s3_hook = S3Hook(aws_conn_id="s3_test", extra_args=original)
+        assert s3_hook.extra_args == original
+        assert s3_hook.extra_args is not original
+
+        dummy = mock.MagicMock()
+        s3_hook.check_for_key = Mock(return_value=False)
+        mock_upload_fileobj = s3_hook.conn.upload_fileobj = Mock(return_value=None)
+        mock_upload_file = s3_hook.conn.upload_file = Mock(return_value=None)
+
+        # First Call - load_file_obj.
+        s3_hook.load_file_obj(dummy, "mock_key", s3_bucket, encrypt=True, acl_policy="public-read")
+        first_call_extra_args = mock_upload_fileobj.call_args_list[0][1]["ExtraArgs"]
+        assert s3_hook.extra_args == original
+        assert first_call_extra_args is not s3_hook.extra_args
+
+        # Second Call - load_string.
+        s3_hook.load_string("dummy", "mock_key", s3_bucket, acl_policy="bucket-owner-full-control")
+        second_call_extra_args = mock_upload_fileobj.call_args_list[1][1]["ExtraArgs"]
+        assert s3_hook.extra_args == original
+        assert second_call_extra_args is not s3_hook.extra_args
+        assert second_call_extra_args != first_call_extra_args
+
+        # Third Call - load_bytes.
+        s3_hook.load_bytes(b"dummy", "mock_key", s3_bucket, encrypt=True)
+        third_call_extra_args = mock_upload_fileobj.call_args_list[2][1]["ExtraArgs"]
+        assert s3_hook.extra_args == original
+        assert third_call_extra_args is not s3_hook.extra_args
+        assert third_call_extra_args not in [first_call_extra_args, second_call_extra_args]
+
+        # Fourth Call - load_file.
+        s3_hook.load_file("/dummy.png", "mock_key", s3_bucket, encrypt=True, acl_policy="bucket-owner-read")
+        fourth_call_extra_args = mock_upload_file.call_args_list[0][1]["ExtraArgs"]
+        assert s3_hook.extra_args == original
+        assert fourth_call_extra_args is not s3_hook.extra_args
+        assert fourth_call_extra_args not in [
+            third_call_extra_args,
+            first_call_extra_args,
+            second_call_extra_args,
+        ]
+
     @mock_s3
     def test_get_bucket_tagging_no_tags_raises_error(self):
         hook = S3Hook()