Posted to commits@airflow.apache.org by as...@apache.org on 2021/02/24 13:42:25 UTC

[airflow] branch master updated: Avoid using threads in S3 remote logging upload (#14414)

This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/master by this push:
     new 0d6cae4  Avoid using threads in S3 remote logging upload (#14414)
0d6cae4 is described below

commit 0d6cae4172ff185ec4c0fc483bf556ce3252b7b0
Author: Ruben Laguna <ru...@gmail.com>
AuthorDate: Wed Feb 24 14:42:13 2021 +0100

    Avoid using threads in S3 remote logging upload (#14414)
    
    This prevents `RuntimeError: cannot schedule new futures after
    interpreter shutdown`
---
 airflow/providers/amazon/aws/hooks/s3.py             | 20 +++++++++++++++++---
 airflow/providers/amazon/aws/log/s3_task_handler.py  |  2 +-
 tests/providers/amazon/aws/hooks/test_s3.py          | 10 ++++++++++
 .../providers/amazon/aws/log/test_s3_task_handler.py |  1 +
 4 files changed, 29 insertions(+), 4 deletions(-)
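
For readers who want the new knob directly: a minimal sketch of constructing the hook with the kwarg this commit introduces (the bucket, key, and file names below are hypothetical; any option accepted by boto3's TransferConfig can be passed the same way):

    from airflow.providers.amazon.aws.hooks.s3 import S3Hook

    # transfer_config_args is forwarded to boto3.s3.transfer.TransferConfig;
    # use_threads=False makes boto3 upload on the calling thread instead of
    # an internal worker thread pool.
    hook = S3Hook(transfer_config_args={"use_threads": False})
    hook.load_file(
        filename="/tmp/attempt_1.log",   # hypothetical local log file
        key="logs/attempt_1.log",        # hypothetical S3 key
        bucket_name="my-log-bucket",     # hypothetical bucket
    )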

diff --git a/airflow/providers/amazon/aws/hooks/s3.py b/airflow/providers/amazon/aws/hooks/s3.py
index de0202a..776a41c 100644
--- a/airflow/providers/amazon/aws/hooks/s3.py
+++ b/airflow/providers/amazon/aws/hooks/s3.py
@@ -30,7 +30,7 @@ from tempfile import NamedTemporaryFile
 from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union, cast
 from urllib.parse import urlparse
 
-from boto3.s3.transfer import S3Transfer
+from boto3.s3.transfer import S3Transfer, TransferConfig
 from botocore.exceptions import ClientError
 
 from airflow.exceptions import AirflowException
@@ -116,6 +116,14 @@ class S3Hook(AwsBaseHook):
                 raise ValueError(f"extra_args '{self.extra_args!r}' must be of type {dict}")
             del kwargs['extra_args']
 
+        self.transfer_config = TransferConfig()
+        if 'transfer_config_args' in kwargs:
+            transport_config_args = kwargs['transfer_config_args']
+            if not isinstance(transport_config_args, dict):
+                raise ValueError(f"transfer_config_args '{transport_config_args!r} must be of type {dict}")
+            self.transfer_config = TransferConfig(**transport_config_args)
+            del kwargs['transfer_config_args']
+
         super().__init__(*args, **kwargs)
 
     @staticmethod
@@ -502,7 +510,7 @@ class S3Hook(AwsBaseHook):
             extra_args['ACL'] = acl_policy
 
         client = self.get_conn()
-        client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args)
+        client.upload_file(filename, bucket_name, key, ExtraArgs=extra_args, Config=self.transfer_config)
 
     @provide_bucket_name
     @unify_bucket_name_and_key
@@ -651,7 +659,13 @@ class S3Hook(AwsBaseHook):
             extra_args['ACL'] = acl_policy
 
         client = self.get_conn()
-        client.upload_fileobj(file_obj, bucket_name, key, ExtraArgs=extra_args)
+        client.upload_fileobj(
+            file_obj,
+            bucket_name,
+            key,
+            ExtraArgs=extra_args,
+            Config=self.transfer_config,
+        )
 
     def copy_object(
         self,
diff --git a/airflow/providers/amazon/aws/log/s3_task_handler.py b/airflow/providers/amazon/aws/log/s3_task_handler.py
index 7fdeac3..663feb5 100644
--- a/airflow/providers/amazon/aws/log/s3_task_handler.py
+++ b/airflow/providers/amazon/aws/log/s3_task_handler.py
@@ -46,7 +46,7 @@ class S3TaskHandler(FileTaskHandler, LoggingMixin):
         try:
             from airflow.providers.amazon.aws.hooks.s3 import S3Hook
 
-            return S3Hook(remote_conn_id)
+            return S3Hook(remote_conn_id, transfer_config_args={"use_threads": False})
         except Exception as e:  # pylint: disable=broad-except
             self.log.exception(
                 'Could not create an S3Hook with connection id "%s". '
diff --git a/tests/providers/amazon/aws/hooks/test_s3.py b/tests/providers/amazon/aws/hooks/test_s3.py
index 8328b65..027a096 100644
--- a/tests/providers/amazon/aws/hooks/test_s3.py
+++ b/tests/providers/amazon/aws/hooks/test_s3.py
@@ -43,6 +43,16 @@ class TestAwsS3Hook:
         hook = S3Hook()
         assert hook.get_conn() is not None
 
+    @mock_s3
+    def test_use_threads_default_value(self):
+        hook = S3Hook()
+        assert hook.transfer_config.use_threads is True
+
+    @mock_s3
+    def test_use_threads_set_value(self):
+        hook = S3Hook(transfer_config_args={"use_threads": False})
+        assert hook.transfer_config.use_threads is False
+
     def test_parse_s3_url(self):
         parsed = S3Hook.parse_s3_url("s3://test/this/is/not/a-real-key.txt")
         assert parsed == ("test", "this/is/not/a-real-key.txt"), "Incorrect parsing of the s3 url"
diff --git a/tests/providers/amazon/aws/log/test_s3_task_handler.py b/tests/providers/amazon/aws/log/test_s3_task_handler.py
index d1437ca..e720081 100644
--- a/tests/providers/amazon/aws/log/test_s3_task_handler.py
+++ b/tests/providers/amazon/aws/log/test_s3_task_handler.py
@@ -80,6 +80,7 @@ class TestS3TaskHandler(unittest.TestCase):
 
     def test_hook(self):
         assert isinstance(self.s3_task_handler.hook, S3Hook)
+        assert self.s3_task_handler.hook.transfer_config.use_threads is False
 
     @conf_vars({('logging', 'remote_log_conn_id'): 'aws_default'})
     def test_hook_raises(self):
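
For context on the error this change avoids: Python's concurrent.futures refuses to schedule work once shutdown has begun, which is what boto3's upload thread pool runs into when the S3TaskHandler flushes logs during interpreter teardown. A contrived sketch of the same failure mode, using only the standard library (here the shutdown is explicit; at interpreter exit CPython raises the "after interpreter shutdown" variant quoted in the commit message):

    from concurrent.futures import ThreadPoolExecutor

    executor = ThreadPoolExecutor(max_workers=2)  # stands in for boto3's pool
    executor.shutdown()
    # Raises RuntimeError: cannot schedule new futures after shutdown
    executor.submit(print, "upload chunk")

With use_threads=False the upload never touches a thread pool, so there is nothing left to schedule against during shutdown.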