You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@beam.apache.org by al...@apache.org on 2021/12/22 17:10:39 UTC

[beam] branch master updated: Make S3 streaming more efficient (#15931)

This is an automated email from the ASF dual-hosted git repository.

altay pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/beam.git


The following commit(s) were added to refs/heads/master by this push:
     new af2f8ee  Make S3 streaming more efficient (#15931)
af2f8ee is described below

commit af2f8ee6cf39a1d63818dbefef322740d5ad794c
Author: Janek Bevendorff <ja...@jbev.net>
AuthorDate: Wed Dec 22 18:09:17 2021 +0100

    Make S3 streaming more efficient (#15931)
---
 .../apache_beam/io/aws/clients/s3/boto3_client.py  | 117 +++++++++++----------
 sdks/python/apache_beam/utils/retry.py             |   2 +-
 2 files changed, 63 insertions(+), 56 deletions(-)

diff --git a/sdks/python/apache_beam/io/aws/clients/s3/boto3_client.py b/sdks/python/apache_beam/io/aws/clients/s3/boto3_client.py
index e2f968d..878ad07 100644
--- a/sdks/python/apache_beam/io/aws/clients/s3/boto3_client.py
+++ b/sdks/python/apache_beam/io/aws/clients/s3/boto3_client.py
@@ -19,6 +19,7 @@
 
 from apache_beam.io.aws.clients.s3 import messages
 from apache_beam.options import pipeline_options
+from apache_beam.utils import retry
 
 try:
   # pylint: disable=wrong-import-order, wrong-import-position
@@ -29,6 +30,12 @@ except ImportError:
   boto3 = None
 
 
+def get_http_error_code(exc):
+  if hasattr(exc, 'response'):
+    return exc.response.get('ResponseMetadata', {}).get('HTTPStatusCode')
+  return None
+
+
 class Client(object):
   """
   Wrapper for boto3 library
@@ -67,52 +74,67 @@ class Client(object):
         aws_secret_access_key=secret_access_key,
         aws_session_token=session_token)
 
-  def get_object_metadata(self, request):
-    r"""Retrieves an object's metadata.
+    self._download_request = None
+    self._download_stream = None
+    self._download_pos = 0
 
-    Args:
-      request: (GetRequest) input message
+  def get_stream(self, request, start):
+    """Opens a stream object starting at the given position.
 
+    Args:
+      request: (GetRequest) request
+      start: (int) start offset
     Returns:
-      (Object) The response message.
+      (Stream) Boto3 stream object.
     """
-    kwargs = {'Bucket': request.bucket, 'Key': request.object}
-
-    try:
-      boto_response = self.client.head_object(**kwargs)
-    except Exception as e:
-      message = e.response['Error']['Message']
-      code = e.response['ResponseMetadata']['HTTPStatusCode']
-      raise messages.S3ClientError(message, code)
-
-    item = messages.Item(
-        boto_response['ETag'],
-        request.object,
-        boto_response['LastModified'],
-        boto_response['ContentLength'],
-        boto_response['ContentType'])
-
-    return item
 
+    if self._download_request and (
+        start != self._download_pos or
+        request.bucket != self._download_request.bucket or
+        request.object != self._download_request.object):
+      self._download_stream.close()
+      self._download_stream = None
+
+    # noinspection PyProtectedMember
+    if not self._download_stream or self._download_stream._raw_stream.closed:
+      try:
+        self._download_stream = self.client.get_object(
+            Bucket=request.bucket,
+            Key=request.object,
+            Range='bytes={}-'.format(start))['Body']
+        self._download_request = request
+        self._download_pos = start
+      except Exception as e:
+        raise messages.S3ClientError(str(e), get_http_error_code(e))
+
+    return self._download_stream
+
+  @retry.with_exponential_backoff()
   def get_range(self, request, start, end):
     r"""Retrieves an object's contents.
 
       Args:
         request: (GetRequest) request
+        start: (int) start offset
+        end: (int) end offset (exclusive)
       Returns:
         (bytes) The response message.
       """
-    try:
-      boto_response = self.client.get_object(
-          Bucket=request.bucket,
-          Key=request.object,
-          Range='bytes={}-{}'.format(start, end - 1))
-    except Exception as e:
-      message = e.response['Error']['Message']
-      code = e.response['ResponseMetadata']['HTTPStatusCode']
-      raise messages.S3ClientError(message, code)
-
-    return boto_response['Body'].read()  # A bytes object
+    for i in range(2):
+      try:
+        stream = self.get_stream(request, start)
+        data = stream.read(end - start)
+        self._download_pos += len(data)
+        return data
+      except Exception as e:
+        self._download_stream = None
+        self._download_request = None
+        if i == 0:
+          # Read errors are likely with long-lived connections, retry immediately if a read fails once
+          continue
+        if isinstance(e, messages.S3ClientError):
+          raise e
+        raise messages.S3ClientError(str(e), get_http_error_code(e))
 
   def list(self, request):
     r"""Retrieves a list of objects matching the criteria.
@@ -130,9 +152,7 @@ class Client(object):
     try:
       boto_response = self.client.list_objects_v2(**kwargs)
     except Exception as e:
-      message = e.response['Error']['Message']
-      code = e.response['ResponseMetadata']['HTTPStatusCode']
-      raise messages.S3ClientError(message, code)
+      raise messages.S3ClientError(str(e), get_http_error_code(e))
 
     if boto_response['KeyCount'] == 0:
       message = 'Tried to list nonexistent S3 path: s3://%s/%s' % (
@@ -170,9 +190,7 @@ class Client(object):
           ContentType=request.mime_type)
       response = messages.UploadResponse(boto_response['UploadId'])
     except Exception as e:
-      message = e.response['Error']['Message']
-      code = e.response['ResponseMetadata']['HTTPStatusCode']
-      raise messages.S3ClientError(message, code)
+      raise messages.S3ClientError(str(e), get_http_error_code(e))
     return response
 
   def upload_part(self, request):
@@ -194,9 +212,7 @@ class Client(object):
           boto_response['ETag'], request.part_number)
       return response
     except Exception as e:
-      message = e.response['Error']['Message']
-      code = e.response['ResponseMetadata']['HTTPStatusCode']
-      raise messages.S3ClientError(message, code)
+      raise messages.S3ClientError(str(e), get_http_error_code(e))
 
   def complete_multipart_upload(self, request):
     r"""Completes a multipart upload to S3
@@ -214,9 +230,7 @@ class Client(object):
           UploadId=request.upload_id,
           MultipartUpload=parts)
     except Exception as e:
-      message = e.response['Error']['Message']
-      code = e.response['ResponseMetadata']['HTTPStatusCode']
-      raise messages.S3ClientError(message, code)
+      raise messages.S3ClientError(str(e), get_http_error_code(e))
 
   def delete(self, request):
     r"""Deletes given object from bucket
@@ -227,11 +241,8 @@ class Client(object):
     """
     try:
       self.client.delete_object(Bucket=request.bucket, Key=request.object)
-
     except Exception as e:
-      message = e.response['Error']['Message']
-      code = e.response['ResponseMetadata']['HTTPStatusCode']
-      raise messages.S3ClientError(message, code)
+      raise messages.S3ClientError(str(e), get_http_error_code(e))
 
   def delete_batch(self, request):
 
@@ -247,9 +258,7 @@ class Client(object):
     try:
       aws_response = self.client.delete_objects(**aws_request)
     except Exception as e:
-      message = e.response['Error']['Message']
-      code = int(e.response['ResponseMetadata']['HTTPStatusCode'])
-      raise messages.S3ClientError(message, code)
+      raise messages.S3ClientError(str(e), get_http_error_code(e))
 
     deleted = [obj['Key'] for obj in aws_response.get('Deleted', [])]
 
@@ -267,6 +276,4 @@ class Client(object):
       copy_src = {'Bucket': request.src_bucket, 'Key': request.src_key}
       self.client.copy(copy_src, request.dest_bucket, request.dest_key)
     except Exception as e:
-      message = e.response['Error']['Message']
-      code = e.response['ResponseMetadata']['HTTPStatusCode']
-      raise messages.S3ClientError(message, code)
+      raise messages.S3ClientError(str(e), get_http_error_code(e))
diff --git a/sdks/python/apache_beam/utils/retry.py b/sdks/python/apache_beam/utils/retry.py
index 16f2882..4784de1 100644
--- a/sdks/python/apache_beam/utils/retry.py
+++ b/sdks/python/apache_beam/utils/retry.py
@@ -124,7 +124,7 @@ def retry_on_server_errors_filter(exception):
   if (HttpError is not None) and isinstance(exception, HttpError):
     return exception.status_code >= 500
   if (S3ClientError is not None) and isinstance(exception, S3ClientError):
-    return exception.code >= 500
+    return exception.code is None or exception.code >= 500
   return not isinstance(exception, PermanentException)