Posted to commits@arrow.apache.org by tu...@apache.org on 2022/08/15 11:30:36 UTC
[arrow-rs] branch master updated: Replace rusoto with custom implementation for AWS (#2176) (#2352)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 3f0e12d8d Replace rusoto with custom implementation for AWS (#2176) (#2352)
3f0e12d8d is described below
commit 3f0e12d8d362752181c75836d25d424862acc424
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Mon Aug 15 12:30:30 2022 +0100
Replace rusoto with custom implementation for AWS (#2176) (#2352)
* Replace rusoto (#2176)
* Add integration test for metadata endpoint
* Fix WebIdentity
* Fix doc
* Fix handling of multipart errors
* Use separate client for credentials
* Include port in Host header canonical request
* Fix doc link
* Review feedback
---
.github/workflows/object_store.yml | 12 +-
object_store/Cargo.toml | 11 +-
object_store/src/aws.rs | 1343 ---------------------------------
object_store/src/aws/client.rs | 483 ++++++++++++
object_store/src/aws/credential.rs | 590 +++++++++++++++
object_store/src/aws/mod.rs | 646 ++++++++++++++++
object_store/src/azure.rs | 85 +--
object_store/src/client/mod.rs | 2 +
object_store/src/client/pagination.rs | 70 ++
object_store/src/client/token.rs | 10 +-
object_store/src/gcp.rs | 219 +++---
object_store/src/lib.rs | 14 +-
object_store/src/multipart.rs | 59 +-
13 files changed, 1982 insertions(+), 1562 deletions(-)
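Stores are still constructed through AmazonS3Builder; the rewrite swaps the rusoto-backed internals for a reqwest-based client split across aws/client.rs, aws/credential.rs and aws/mod.rs. A minimal sketch of building a store, assuming the builder methods documented in the replaced aws.rs carry over to the new aws/mod.rs (region, bucket and keys below are placeholders):

    use object_store::aws::AmazonS3Builder;

    fn main() -> Result<(), Box<dyn std::error::Error>> {
        let s3 = AmazonS3Builder::new()
            .with_region("us-east-1")
            .with_bucket_name("test-bucket")
            .with_access_key_id("test")
            .with_secret_access_key("test")
            .build()?;
        let _ = s3; // use the store via the ObjectStore trait
        Ok(())
    }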
diff --git a/.github/workflows/object_store.yml b/.github/workflows/object_store.yml
index 6c81604a9..5da2cb4e6 100644
--- a/.github/workflows/object_store.yml
+++ b/.github/workflows/object_store.yml
@@ -59,6 +59,13 @@ jobs:
image: localstack/localstack:0.14.4
ports:
- 4566:4566
+ ec2-metadata:
+ image: amazon/amazon-ec2-metadata-mock:v1.9.2
+ ports:
+ - 1338:1338
+ env:
+ # Only allow IMDSv2
+ AEMM_IMDSV2: "1"
azurite:
image: mcr.microsoft.com/azure-storage/azurite
ports:
@@ -78,6 +85,7 @@ jobs:
AWS_ACCESS_KEY_ID: test
AWS_SECRET_ACCESS_KEY: test
AWS_ENDPOINT: http://localstack:4566
+ EC2_METADATA_ENDPOINT: http://ec2-metadata:1338
AZURE_USE_EMULATOR: "1"
AZURITE_BLOB_STORAGE_URL: "http://azurite:10000"
AZURITE_QUEUE_STORAGE_URL: "http://azurite:10001"
@@ -101,8 +109,8 @@ jobs:
aws --endpoint-url=http://localstack:4566 s3 mb s3://test-bucket
- name: Configure Azurite (Azure emulation)
- # the magical connection string is from
- # https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azurite?tabs=visual-studio#http-connection-strings
+ # the magical connection string is from
+ # https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azurite?tabs=visual-studio#http-connection-strings
run: |
curl -sL https://aka.ms/InstallAzureCLIDeb | bash
az storage container create -n test-bucket --connection-string 'DefaultEndpointsProtocol=http;AccountName=devstoreaccount1;AccountKey=Eby8vdM02xNOcqFlqUwJPLlmEtlCDXJ1OUzFT50uSRZ6IFsuFq2UVErCz4I6tq/K1SZFPTOtr/KBHBeksoGMGw==;BlobEndpoint=http://azurite:10000/devstoreaccount1;QueueEndpoint=http://azurite:10001/devstoreaccount1;'
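The ec2-metadata container added above runs amazon/amazon-ec2-metadata-mock with AEMM_IMDSV2 set, so the new integration test exercises the instance-metadata credential path against IMDSv2 only. A sketch of the two-step IMDSv2 handshake that implies, using the documented IMDSv2 header and path names rather than the exact identifiers in aws/credential.rs:

    use reqwest::Client;

    // Fetch a session token, then present it on the metadata read.
    // `endpoint` would be EC2_METADATA_ENDPOINT in the workflow above.
    async fn imds_v2_get(
        client: &Client,
        endpoint: &str,
        path: &str,
    ) -> Result<String, reqwest::Error> {
        // Step 1: PUT a token request, declaring a session TTL
        let token = client
            .put(format!("{}/latest/api/token", endpoint))
            .header("X-aws-ec2-metadata-token-ttl-seconds", "21600")
            .send()
            .await?
            .text()
            .await?;

        // Step 2: GET the metadata, presenting the session token
        client
            .get(format!("{}{}", endpoint, path))
            .header("X-aws-ec2-metadata-token", token.as_str())
            .send()
            .await?
            .text()
            .await
    }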
diff --git a/object_store/Cargo.toml b/object_store/Cargo.toml
index bb371988a..8c713d80b 100644
--- a/object_store/Cargo.toml
+++ b/object_store/Cargo.toml
@@ -46,17 +46,8 @@ rustls-pemfile = { version = "1.0", default-features = false, optional = true }
ring = { version = "0.16", default-features = false, features = ["std"], optional = true }
base64 = { version = "0.13", default-features = false, optional = true }
rand = { version = "0.8", default-features = false, features = ["std", "std_rng"], optional = true }
-# for rusoto
-hyper = { version = "0.14", optional = true, default-features = false }
-# for rusoto
-hyper-rustls = { version = "0.23.0", optional = true, default-features = false, features = ["webpki-tokio", "http1", "http2", "tls12"] }
itertools = "0.10.1"
percent-encoding = "2.1"
-# rusoto crates are for Amazon S3 integration
-rusoto_core = { version = "0.48.0", optional = true, default-features = false, features = ["rustls"] }
-rusoto_credential = { version = "0.48.0", optional = true, default-features = false }
-rusoto_s3 = { version = "0.48.0", optional = true, default-features = false, features = ["rustls"] }
-rusoto_sts = { version = "0.48.0", optional = true, default-features = false, features = ["rustls"] }
snafu = "0.7"
tokio = { version = "1.18", features = ["sync", "macros", "parking_lot", "rt-multi-thread", "time", "io-util"] }
tracing = { version = "0.1" }
@@ -70,7 +61,7 @@ walkdir = "2"
azure = ["azure_core", "azure_storage_blobs", "azure_storage", "reqwest", "azure_identity"]
azure_test = ["azure", "azure_core/azurite_workaround", "azure_storage/azurite_workaround", "azure_storage_blobs/azurite_workaround"]
gcp = ["serde", "serde_json", "quick-xml", "reqwest", "reqwest/json", "reqwest/stream", "chrono/serde", "rustls-pemfile", "base64", "rand", "ring"]
-aws = ["rusoto_core", "rusoto_credential", "rusoto_s3", "rusoto_sts", "hyper", "hyper-rustls"]
+aws = ["serde", "serde_json", "quick-xml", "reqwest", "reqwest/json", "reqwest/stream", "chrono/serde", "rustls-pemfile", "base64", "rand", "ring"]
[dev-dependencies] # In alphabetical order
dotenv = "0.15.0"
diff --git a/object_store/src/aws.rs b/object_store/src/aws.rs
deleted file mode 100644
index bcb294c00..000000000
--- a/object_store/src/aws.rs
+++ /dev/null
@@ -1,1343 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements. See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership. The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License. You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied. See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-//! An object store implementation for S3
-//!
-//! ## Multi-part uploads
-//!
-//! Multi-part uploads can be initiated with the [ObjectStore::put_multipart] method.
-//! Data passed to the writer is automatically buffered to meet the minimum size
-//! requirements for a part. Multiple parts are uploaded concurrently.
-//!
-//! If the writer fails for any reason, you may be left with parts uploaded to AWS
-//! but never used, and you may be charged for them. Use the
-//! [ObjectStore::abort_multipart] method to abort the upload and drop those
-//! unneeded parts. In addition, you may wish to consider implementing
-//! [automatic cleanup] of unused parts that are older than one week.
-//!
-//! [automatic cleanup]: https://aws.amazon.com/blogs/aws/s3-lifecycle-management-update-support-for-multipart-uploads-and-delete-markers/
-use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart};
-use crate::util::format_http_range;
-use crate::MultipartId;
-use crate::{
- collect_bytes,
- path::{Path, DELIMITER},
- util::format_prefix,
- GetResult, ListResult, ObjectMeta, ObjectStore, Result,
-};
-use async_trait::async_trait;
-use bytes::Bytes;
-use chrono::{DateTime, Utc};
-use futures::future::BoxFuture;
-use futures::{
- stream::{self, BoxStream},
- Future, Stream, StreamExt, TryStreamExt,
-};
-use hyper::client::Builder as HyperBuilder;
-use percent_encoding::{percent_encode, AsciiSet, NON_ALPHANUMERIC};
-use rusoto_core::ByteStream;
-use rusoto_credential::{InstanceMetadataProvider, StaticProvider};
-use rusoto_s3::S3;
-use rusoto_sts::WebIdentityProvider;
-use snafu::{OptionExt, ResultExt, Snafu};
-use std::io;
-use std::ops::Range;
-use std::{
- convert::TryFrom, fmt, num::NonZeroUsize, ops::Deref, sync::Arc, time::Duration,
-};
-use tokio::io::AsyncWrite;
-use tokio::sync::{OwnedSemaphorePermit, Semaphore};
-use tracing::{debug, warn};
-
-// Do not URI-encode any of the unreserved characters that RFC 3986 defines:
-// A-Z, a-z, 0-9, hyphen ( - ), underscore ( _ ), period ( . ), and tilde ( ~ ).
-const STRICT_ENCODE_SET: AsciiSet = NON_ALPHANUMERIC
- .remove(b'-')
- .remove(b'.')
- .remove(b'_')
- .remove(b'~');
-
-/// This set is used to maintain the URI path encoding
-const STRICT_PATH_ENCODE_SET: AsciiSet = STRICT_ENCODE_SET.remove(b'/');
-
-/// The maximum number of times a request will be retried in the case of an AWS server error
-pub const MAX_NUM_RETRIES: u32 = 3;
-
-/// A specialized `Error` for object store-related errors
-#[derive(Debug, Snafu)]
-#[allow(missing_docs)]
-enum Error {
- #[snafu(display(
- "Expected streamed data to have length {}, got {}",
- expected,
- actual
- ))]
- DataDoesNotMatchLength { expected: usize, actual: usize },
-
- #[snafu(display(
- "Did not receive any data. Bucket: {}, Location: {}",
- bucket,
- path
- ))]
- NoData { bucket: String, path: String },
-
- #[snafu(display(
- "Unable to DELETE data. Bucket: {}, Location: {}, Error: {} ({:?})",
- bucket,
- path,
- source,
- source,
- ))]
- UnableToDeleteData {
- source: rusoto_core::RusotoError<rusoto_s3::DeleteObjectError>,
- bucket: String,
- path: String,
- },
-
- #[snafu(display(
- "Unable to GET data. Bucket: {}, Location: {}, Error: {} ({:?})",
- bucket,
- path,
- source,
- source,
- ))]
- UnableToGetData {
- source: rusoto_core::RusotoError<rusoto_s3::GetObjectError>,
- bucket: String,
- path: String,
- },
-
- #[snafu(display(
- "Unable to HEAD data. Bucket: {}, Location: {}, Error: {} ({:?})",
- bucket,
- path,
- source,
- source,
- ))]
- UnableToHeadData {
- source: rusoto_core::RusotoError<rusoto_s3::HeadObjectError>,
- bucket: String,
- path: String,
- },
-
- #[snafu(display(
- "Unable to GET part of the data. Bucket: {}, Location: {}, Error: {} ({:?})",
- bucket,
- path,
- source,
- source,
- ))]
- UnableToGetPieceOfData {
- source: std::io::Error,
- bucket: String,
- path: String,
- },
-
- #[snafu(display(
- "Unable to PUT data. Bucket: {}, Location: {}, Error: {} ({:?})",
- bucket,
- path,
- source,
- source,
- ))]
- UnableToPutData {
- source: rusoto_core::RusotoError<rusoto_s3::PutObjectError>,
- bucket: String,
- path: String,
- },
-
- #[snafu(display(
- "Unable to upload data. Bucket: {}, Location: {}, Error: {} ({:?})",
- bucket,
- path,
- source,
- source,
- ))]
- UnableToUploadData {
- source: rusoto_core::RusotoError<rusoto_s3::CreateMultipartUploadError>,
- bucket: String,
- path: String,
- },
-
- #[snafu(display(
- "Unable to cleanup multipart data. Bucket: {}, Location: {}, Error: {} ({:?})",
- bucket,
- path,
- source,
- source,
- ))]
- UnableToCleanupMultipartData {
- source: rusoto_core::RusotoError<rusoto_s3::AbortMultipartUploadError>,
- bucket: String,
- path: String,
- },
-
- #[snafu(display(
- "Unable to list data. Bucket: {}, Error: {} ({:?})",
- bucket,
- source,
- source,
- ))]
- UnableToListData {
- source: rusoto_core::RusotoError<rusoto_s3::ListObjectsV2Error>,
- bucket: String,
- },
-
- #[snafu(display(
- "Unable to copy object. Bucket: {}, From: {}, To: {}, Error: {}",
- bucket,
- from,
- to,
- source,
- ))]
- UnableToCopyObject {
- source: rusoto_core::RusotoError<rusoto_s3::CopyObjectError>,
- bucket: String,
- from: String,
- to: String,
- },
-
- #[snafu(display(
- "Unable to parse last modified date. Bucket: {}, Error: {} ({:?})",
- bucket,
- source,
- source,
- ))]
- UnableToParseLastModified {
- source: chrono::ParseError,
- bucket: String,
- },
-
- #[snafu(display(
- "Unable to buffer data into temporary file, Error: {} ({:?})",
- source,
- source,
- ))]
- UnableToBufferStream { source: std::io::Error },
-
- #[snafu(display(
- "Could not parse `{}` as an AWS region. Regions should look like `us-east-2`. {} ({:?})",
- region,
- source,
- source,
- ))]
- InvalidRegion {
- region: String,
- source: rusoto_core::region::ParseRegionError,
- },
-
- #[snafu(display(
- "Region must be specified for AWS S3. Regions should look like `us-east-2`"
- ))]
- MissingRegion {},
-
- #[snafu(display("Missing bucket name"))]
- MissingBucketName {},
-
- #[snafu(display("Missing aws-access-key"))]
- MissingAccessKey,
-
- #[snafu(display("Missing aws-secret-access-key"))]
- MissingSecretAccessKey,
-
- NotFound {
- path: String,
- source: Box<dyn std::error::Error + Send + Sync + 'static>,
- },
-}
-
-impl From<Error> for super::Error {
- fn from(source: Error) -> Self {
- match source {
- Error::NotFound { path, source } => Self::NotFound { path, source },
- _ => Self::Generic {
- store: "S3",
- source: Box::new(source),
- },
- }
- }
-}
-
-/// Interface for [Amazon S3](https://aws.amazon.com/s3/).
-pub struct AmazonS3 {
- /// S3 client w/o any connection limit.
- ///
- /// You should normally use [`Self::client`] instead.
- client_unrestricted: rusoto_s3::S3Client,
-
- /// Semaphore that limits the usage of [`client_unrestricted`](Self::client_unrestricted).
- connection_semaphore: Arc<Semaphore>,
-
- /// Bucket name used by this object store client.
- bucket_name: String,
-}
-
-impl fmt::Debug for AmazonS3 {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- f.debug_struct("AmazonS3")
- .field("client", &"rusoto_s3::S3Client")
- .field("bucket_name", &self.bucket_name)
- .finish()
- }
-}
-
-impl fmt::Display for AmazonS3 {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- write!(f, "AmazonS3({})", self.bucket_name)
- }
-}
-
-#[async_trait]
-impl ObjectStore for AmazonS3 {
- async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> {
- let bucket_name = self.bucket_name.clone();
- let request_factory = move || {
- let bytes = bytes.clone();
-
- let length = bytes.len();
- let stream_data = Ok(bytes);
- let stream = futures::stream::once(async move { stream_data });
- let byte_stream = ByteStream::new_with_size(stream, length);
-
- rusoto_s3::PutObjectRequest {
- bucket: bucket_name.clone(),
- key: location.to_string(),
- body: Some(byte_stream),
- ..Default::default()
- }
- };
-
- let s3 = self.client().await;
-
- s3_request(move || {
- let (s3, request_factory) = (s3.clone(), request_factory.clone());
-
- async move { s3.put_object(request_factory()).await }
- })
- .await
- .context(UnableToPutDataSnafu {
- bucket: &self.bucket_name,
- path: location.as_ref(),
- })?;
-
- Ok(())
- }
-
- async fn put_multipart(
- &self,
- location: &Path,
- ) -> Result<(MultipartId, Box<dyn AsyncWrite + Unpin + Send>)> {
- let bucket_name = self.bucket_name.clone();
-
- let request_factory = move || rusoto_s3::CreateMultipartUploadRequest {
- bucket: bucket_name.clone(),
- key: location.to_string(),
- ..Default::default()
- };
-
- let s3 = self.client().await;
-
- let data = s3_request(move || {
- let (s3, request_factory) = (s3.clone(), request_factory.clone());
-
- async move { s3.create_multipart_upload(request_factory()).await }
- })
- .await
- .context(UnableToUploadDataSnafu {
- bucket: &self.bucket_name,
- path: location.as_ref(),
- })?;
-
- let upload_id = data.upload_id.unwrap();
-
- let inner = S3MultiPartUpload {
- upload_id: upload_id.clone(),
- bucket: self.bucket_name.clone(),
- key: location.to_string(),
- client_unrestricted: self.client_unrestricted.clone(),
- connection_semaphore: Arc::clone(&self.connection_semaphore),
- };
-
- Ok((upload_id, Box::new(CloudMultiPartUpload::new(inner, 8))))
- }
-
- async fn abort_multipart(
- &self,
- location: &Path,
- multipart_id: &MultipartId,
- ) -> Result<()> {
- let request_factory = move || rusoto_s3::AbortMultipartUploadRequest {
- bucket: self.bucket_name.clone(),
- key: location.to_string(),
- upload_id: multipart_id.to_string(),
- ..Default::default()
- };
-
- let s3 = self.client().await;
- s3_request(move || {
- let (s3, request_factory) = (s3.clone(), request_factory);
-
- async move { s3.abort_multipart_upload(request_factory()).await }
- })
- .await
- .context(UnableToCleanupMultipartDataSnafu {
- bucket: &self.bucket_name,
- path: location.as_ref(),
- })?;
-
- Ok(())
- }
-
- async fn get(&self, location: &Path) -> Result<GetResult> {
- Ok(GetResult::Stream(
- self.get_object(location, None).await?.boxed(),
- ))
- }
-
- async fn get_range(&self, location: &Path, range: Range<usize>) -> Result<Bytes> {
- let size_hint = range.end - range.start;
- let stream = self.get_object(location, Some(range)).await?;
- collect_bytes(stream, Some(size_hint)).await
- }
-
- async fn head(&self, location: &Path) -> Result<ObjectMeta> {
- let key = location.to_string();
- let head_request = rusoto_s3::HeadObjectRequest {
- bucket: self.bucket_name.clone(),
- key: key.clone(),
- ..Default::default()
- };
- let s = self
- .client()
- .await
- .head_object(head_request)
- .await
- .map_err(|e| match e {
- rusoto_core::RusotoError::Service(
- rusoto_s3::HeadObjectError::NoSuchKey(_),
- ) => Error::NotFound {
- path: key.clone(),
- source: e.into(),
- },
- rusoto_core::RusotoError::Unknown(h) if h.status.as_u16() == 404 => {
- Error::NotFound {
- path: key.clone(),
- source: "resource not found".into(),
- }
- }
- _ => Error::UnableToHeadData {
- bucket: self.bucket_name.to_owned(),
- path: key.clone(),
- source: e,
- },
- })?;
-
- // Note: GetObject and HeadObject return a different date format from ListObjects
- //
- // S3 List returns timestamps in the form
- // <LastModified>2013-09-17T18:07:53.000Z</LastModified>
- // S3 GetObject returns timestamps in the form
- // Last-Modified: Sun, 1 Jan 2006 12:00:00 GMT
- let last_modified = match s.last_modified {
- Some(lm) => DateTime::parse_from_rfc2822(&lm)
- .context(UnableToParseLastModifiedSnafu {
- bucket: &self.bucket_name,
- })?
- .with_timezone(&Utc),
- None => Utc::now(),
- };
-
- Ok(ObjectMeta {
- last_modified,
- location: location.clone(),
- size: usize::try_from(s.content_length.unwrap_or(0))
- .expect("unsupported size on this platform"),
- })
- }
-
- async fn delete(&self, location: &Path) -> Result<()> {
- let bucket_name = self.bucket_name.clone();
-
- let request_factory = move || rusoto_s3::DeleteObjectRequest {
- bucket: bucket_name.clone(),
- key: location.to_string(),
- ..Default::default()
- };
-
- let s3 = self.client().await;
-
- s3_request(move || {
- let (s3, request_factory) = (s3.clone(), request_factory.clone());
-
- async move { s3.delete_object(request_factory()).await }
- })
- .await
- .context(UnableToDeleteDataSnafu {
- bucket: &self.bucket_name,
- path: location.as_ref(),
- })?;
-
- Ok(())
- }
-
- async fn list(
- &self,
- prefix: Option<&Path>,
- ) -> Result<BoxStream<'_, Result<ObjectMeta>>> {
- Ok(self
- .list_objects_v2(prefix, None)
- .await?
- .map_ok(move |list_objects_v2_result| {
- let contents = list_objects_v2_result.contents.unwrap_or_default();
- let iter = contents
- .into_iter()
- .map(|object| convert_object_meta(object, &self.bucket_name));
-
- futures::stream::iter(iter)
- })
- .try_flatten()
- .boxed())
- }
-
- async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result<ListResult> {
- Ok(self
- .list_objects_v2(prefix, Some(DELIMITER.to_string()))
- .await?
- .try_fold(
- ListResult {
- common_prefixes: vec![],
- objects: vec![],
- },
- |acc, list_objects_v2_result| async move {
- let mut res = acc;
- let contents = list_objects_v2_result.contents.unwrap_or_default();
- let mut objects = contents
- .into_iter()
- .map(|object| convert_object_meta(object, &self.bucket_name))
- .collect::<Result<Vec<_>>>()?;
-
- res.objects.append(&mut objects);
-
- let prefixes =
- list_objects_v2_result.common_prefixes.unwrap_or_default();
- res.common_prefixes.reserve(prefixes.len());
-
- for p in prefixes {
- let prefix =
- p.prefix.expect("can't have a prefix without a value");
- res.common_prefixes.push(Path::parse(prefix)?);
- }
-
- Ok(res)
- },
- )
- .await?)
- }
-
- async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
- let from = from.as_ref();
- let to = to.as_ref();
- let bucket_name = self.bucket_name.clone();
-
- let copy_source = format!(
- "{}/{}",
- &bucket_name,
- percent_encode(from.as_ref(), &STRICT_PATH_ENCODE_SET)
- );
-
- let request_factory = move || rusoto_s3::CopyObjectRequest {
- bucket: bucket_name.clone(),
- copy_source,
- key: to.to_string(),
- ..Default::default()
- };
-
- let s3 = self.client().await;
-
- s3_request(move || {
- let (s3, request_factory) = (s3.clone(), request_factory.clone());
-
- async move { s3.copy_object(request_factory()).await }
- })
- .await
- .context(UnableToCopyObjectSnafu {
- bucket: &self.bucket_name,
- from,
- to,
- })?;
-
- Ok(())
- }
-
- async fn copy_if_not_exists(&self, _source: &Path, _dest: &Path) -> Result<()> {
- // Will need dynamodb_lock
- Err(crate::Error::NotImplemented)
- }
-}
-
-fn convert_object_meta(object: rusoto_s3::Object, bucket: &str) -> Result<ObjectMeta> {
- let key = object.key.expect("object doesn't exist without a key");
- let location = Path::parse(key)?;
- let last_modified = match object.last_modified {
- Some(lm) => DateTime::parse_from_rfc3339(&lm)
- .context(UnableToParseLastModifiedSnafu { bucket })?
- .with_timezone(&Utc),
- None => Utc::now(),
- };
- let size = usize::try_from(object.size.unwrap_or(0))
- .expect("unsupported size on this platform");
-
- Ok(ObjectMeta {
- location,
- last_modified,
- size,
- })
-}
-
-/// Configure a connection to Amazon S3 using the specified credentials in
-/// the specified Amazon region and bucket.
-///
-/// # Example
-/// ```
-/// # let REGION = "foo";
-/// # let BUCKET_NAME = "foo";
-/// # let ACCESS_KEY_ID = "foo";
-/// # let SECRET_KEY = "foo";
-/// # use object_store::aws::AmazonS3Builder;
-/// let s3 = AmazonS3Builder::new()
-/// .with_region(REGION)
-/// .with_bucket_name(BUCKET_NAME)
-/// .with_access_key_id(ACCESS_KEY_ID)
-/// .with_secret_access_key(SECRET_KEY)
-/// .build();
-/// ```
-#[derive(Debug)]
-pub struct AmazonS3Builder {
- access_key_id: Option<String>,
- secret_access_key: Option<String>,
- region: Option<String>,
- bucket_name: Option<String>,
- endpoint: Option<String>,
- token: Option<String>,
- max_connections: NonZeroUsize,
- allow_http: bool,
-}
-
-impl Default for AmazonS3Builder {
- fn default() -> Self {
- Self {
- access_key_id: None,
- secret_access_key: None,
- region: None,
- bucket_name: None,
- endpoint: None,
- token: None,
- max_connections: NonZeroUsize::new(16).unwrap(),
- allow_http: false,
- }
- }
-}
-
-impl AmazonS3Builder {
- /// Create a new [`AmazonS3Builder`] with default values.
- pub fn new() -> Self {
- Default::default()
- }
-
- /// Set the AWS Access Key (required)
- pub fn with_access_key_id(mut self, access_key_id: impl Into<String>) -> Self {
- self.access_key_id = Some(access_key_id.into());
- self
- }
-
- /// Set the AWS Secret Access Key (required)
- pub fn with_secret_access_key(
- mut self,
- secret_access_key: impl Into<String>,
- ) -> Self {
- self.secret_access_key = Some(secret_access_key.into());
- self
- }
-
- /// Set the region (e.g. `us-east-1`) (required)
- pub fn with_region(mut self, region: impl Into<String>) -> Self {
- self.region = Some(region.into());
- self
- }
-
- /// Set the bucket_name (required)
- pub fn with_bucket_name(mut self, bucket_name: impl Into<String>) -> Self {
- self.bucket_name = Some(bucket_name.into());
- self
- }
-
- /// Sets the endpoint for communicating with AWS S3. Default value
- /// is based on region.
- ///
-    /// For example, this might be set to `"http://localhost:4566"`
- /// for testing against a localstack instance.
- pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
- self.endpoint = Some(endpoint.into());
- self
- }
-
- /// Set the token to use for requests (passed to underlying provider)
- pub fn with_token(mut self, token: impl Into<String>) -> Self {
- self.token = Some(token.into());
- self
- }
-
- /// Sets the maximum number of concurrent outstanding
-    /// connections. Default is `16`.
- #[deprecated(note = "use LimitStore instead")]
- pub fn with_max_connections(mut self, max_connections: NonZeroUsize) -> Self {
- self.max_connections = max_connections;
- self
- }
-
-    /// Sets what protocol is allowed. If `allow_http` is:
-    /// * false (default): Only HTTPS is allowed
-    /// * true: HTTP and HTTPS are allowed
- pub fn with_allow_http(mut self, allow_http: bool) -> Self {
- self.allow_http = allow_http;
- self
- }
-
- /// Create a [`AmazonS3`] instance from the provided values,
- /// consuming `self`.
- pub fn build(self) -> Result<AmazonS3> {
- let Self {
- access_key_id,
- secret_access_key,
- region,
- bucket_name,
- endpoint,
- token,
- max_connections,
- allow_http,
- } = self;
-
- let region = region.ok_or(Error::MissingRegion {})?;
- let bucket_name = bucket_name.ok_or(Error::MissingBucketName {})?;
-
- let region: rusoto_core::Region = match endpoint {
- None => region.parse().context(InvalidRegionSnafu { region })?,
- Some(endpoint) => rusoto_core::Region::Custom {
- name: region,
- endpoint,
- },
- };
-
- let mut builder = HyperBuilder::default();
- builder.pool_max_idle_per_host(max_connections.get());
-
- let connector = if allow_http {
- hyper_rustls::HttpsConnectorBuilder::new()
- .with_webpki_roots()
- .https_or_http()
- .enable_http1()
- .enable_http2()
- .build()
- } else {
- hyper_rustls::HttpsConnectorBuilder::new()
- .with_webpki_roots()
- .https_only()
- .enable_http1()
- .enable_http2()
- .build()
- };
-
- let http_client =
- rusoto_core::request::HttpClient::from_builder(builder, connector);
-
- let client = match (access_key_id, secret_access_key, token) {
- (Some(access_key_id), Some(secret_access_key), Some(token)) => {
- let credentials_provider = StaticProvider::new(
- access_key_id,
- secret_access_key,
- Some(token),
- None,
- );
- rusoto_s3::S3Client::new_with(http_client, credentials_provider, region)
- }
- (Some(access_key_id), Some(secret_access_key), None) => {
- let credentials_provider =
- StaticProvider::new_minimal(access_key_id, secret_access_key);
- rusoto_s3::S3Client::new_with(http_client, credentials_provider, region)
- }
- (None, Some(_), _) => return Err(Error::MissingAccessKey.into()),
- (Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()),
- _ if std::env::var_os("AWS_WEB_IDENTITY_TOKEN_FILE").is_some() => {
- rusoto_s3::S3Client::new_with(
- http_client,
- WebIdentityProvider::from_k8s_env(),
- region,
- )
- }
- _ => rusoto_s3::S3Client::new_with(
- http_client,
- InstanceMetadataProvider::new(),
- region,
- ),
- };
-
- Ok(AmazonS3 {
- client_unrestricted: client,
- connection_semaphore: Arc::new(Semaphore::new(max_connections.get())),
- bucket_name,
- })
- }
-}
-
-/// S3 client bundled w/ a semaphore permit.
-#[derive(Clone)]
-struct SemaphoreClient {
- /// Permit for this specific use of the client.
- ///
- /// Note that this field is never read and therefore considered "dead code" by rustc.
- #[allow(dead_code)]
- permit: Arc<OwnedSemaphorePermit>,
-
- inner: rusoto_s3::S3Client,
-}
-
-impl Deref for SemaphoreClient {
- type Target = rusoto_s3::S3Client;
-
- fn deref(&self) -> &Self::Target {
- &self.inner
- }
-}
-
-impl AmazonS3 {
- /// Get a client according to the current connection limit.
- async fn client(&self) -> SemaphoreClient {
- let permit = Arc::clone(&self.connection_semaphore)
- .acquire_owned()
- .await
- .expect("semaphore shouldn't be closed yet");
- SemaphoreClient {
- permit: Arc::new(permit),
- inner: self.client_unrestricted.clone(),
- }
- }
-
- async fn get_object(
- &self,
- location: &Path,
- range: Option<Range<usize>>,
- ) -> Result<impl Stream<Item = Result<Bytes>>> {
- let key = location.to_string();
- let get_request = rusoto_s3::GetObjectRequest {
- bucket: self.bucket_name.clone(),
- key: key.clone(),
- range: range.map(format_http_range),
- ..Default::default()
- };
- let bucket_name = self.bucket_name.clone();
- let stream = self
- .client()
- .await
- .get_object(get_request)
- .await
- .map_err(|e| match e {
- rusoto_core::RusotoError::Service(
- rusoto_s3::GetObjectError::NoSuchKey(_),
- ) => Error::NotFound {
- path: key.clone(),
- source: e.into(),
- },
- _ => Error::UnableToGetData {
- bucket: self.bucket_name.to_owned(),
- path: key.clone(),
- source: e,
- },
- })?
- .body
- .context(NoDataSnafu {
- bucket: self.bucket_name.to_owned(),
- path: key.clone(),
- })?
- .map_err(move |source| Error::UnableToGetPieceOfData {
- source,
- bucket: bucket_name.clone(),
- path: key.clone(),
- })
- .err_into();
-
- Ok(stream)
- }
-
- async fn list_objects_v2(
- &self,
- prefix: Option<&Path>,
- delimiter: Option<String>,
- ) -> Result<BoxStream<'_, Result<rusoto_s3::ListObjectsV2Output>>> {
- enum ListState {
- Start,
- HasMore(String),
- Done,
- }
-
- let prefix = format_prefix(prefix);
- let bucket = self.bucket_name.clone();
-
- let request_factory = move || rusoto_s3::ListObjectsV2Request {
- bucket,
- prefix,
- delimiter,
- ..Default::default()
- };
- let s3 = self.client().await;
-
- Ok(stream::unfold(ListState::Start, move |state| {
- let request_factory = request_factory.clone();
- let s3 = s3.clone();
-
- async move {
- let continuation_token = match &state {
- ListState::HasMore(continuation_token) => Some(continuation_token),
- ListState::Done => {
- return None;
- }
- // If this is the first request we've made, we don't need to make any
- // modifications to the request
- ListState::Start => None,
- };
-
- let resp = s3_request(move || {
- let (s3, request_factory, continuation_token) = (
- s3.clone(),
- request_factory.clone(),
- continuation_token.cloned(),
- );
-
- async move {
- s3.list_objects_v2(rusoto_s3::ListObjectsV2Request {
- continuation_token,
- ..request_factory()
- })
- .await
- }
- })
- .await;
-
- let resp = match resp {
- Ok(resp) => resp,
- Err(e) => return Some((Err(e), state)),
- };
-
- // The AWS response contains a field named `is_truncated` as well as
- // `next_continuation_token`, and we're assuming that `next_continuation_token`
- // is only set when `is_truncated` is true (and therefore not
- // checking `is_truncated`).
- let next_state = if let Some(next_continuation_token) =
- &resp.next_continuation_token
- {
- ListState::HasMore(next_continuation_token.to_string())
- } else {
- ListState::Done
- };
-
- Some((Ok(resp), next_state))
- }
- })
- .map_err(move |e| {
- Error::UnableToListData {
- source: e,
- bucket: self.bucket_name.clone(),
- }
- .into()
- })
- .boxed())
- }
-}
-
-/// Handles retrying a request to S3 up to `MAX_NUM_RETRIES` times if S3 returns 5xx server errors.
-///
-/// The `future_factory` argument is a function `F` that takes no arguments and, when called, will
-/// return a `Future` (type `G`) that, when `await`ed, will perform a request to S3 through
-/// `rusoto` and return a `Result` that yields some type `R` on success and some
-/// `rusoto_core::RusotoError<E>` on error.
-///
-/// If the executed `Future` returns success, this function will return that success.
-/// If the executed `Future` returns a 5xx server error, this function will wait an amount of
-/// time that increases exponentially with the number of times it has retried, get a new `Future` by
-/// calling `future_factory` again, and retry the request by `await`ing the `Future` again.
-/// The retries will continue until the maximum number of retries has been attempted. In that case,
-/// this function will return the last encountered error.
-///
-/// Client errors (4xx) will never be retried by this function.
-async fn s3_request<E, F, G, R>(
- future_factory: F,
-) -> Result<R, rusoto_core::RusotoError<E>>
-where
- E: std::error::Error + Send,
- F: Fn() -> G + Send,
- G: Future<Output = Result<R, rusoto_core::RusotoError<E>>> + Send,
- R: Send,
-{
- let mut attempts = 0;
-
- loop {
- let request = future_factory();
-
- let result = request.await;
-
- match result {
- Ok(r) => return Ok(r),
- Err(error) => {
- attempts += 1;
-
- let should_retry = matches!(
- error,
- rusoto_core::RusotoError::Unknown(ref response)
- if response.status.is_server_error()
- );
-
- if attempts > MAX_NUM_RETRIES {
- warn!(
- ?error,
- attempts, "maximum number of retries exceeded for AWS S3 request"
- );
- return Err(error);
- } else if !should_retry {
- return Err(error);
- } else {
- debug!(?error, attempts, "retrying AWS S3 request");
- let wait_time = Duration::from_millis(2u64.pow(attempts) * 50);
- tokio::time::sleep(wait_time).await;
- }
- }
- }
- }
-}
-
-struct S3MultiPartUpload {
- bucket: String,
- key: String,
- upload_id: String,
- client_unrestricted: rusoto_s3::S3Client,
- connection_semaphore: Arc<Semaphore>,
-}
-
-impl CloudMultiPartUploadImpl for S3MultiPartUpload {
- fn put_multipart_part(
- &self,
- buf: Vec<u8>,
- part_idx: usize,
- ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>> {
- // Get values to move into future; we don't want a reference to Self
- let bucket = self.bucket.clone();
- let key = self.key.clone();
- let upload_id = self.upload_id.clone();
- let content_length = buf.len();
-
- let request_factory = move || rusoto_s3::UploadPartRequest {
- bucket,
- key,
- upload_id,
- // AWS part number is 1-indexed
- part_number: (part_idx + 1).try_into().unwrap(),
- content_length: Some(content_length.try_into().unwrap()),
- body: Some(buf.into()),
- ..Default::default()
- };
-
- let s3 = self.client_unrestricted.clone();
- let connection_semaphore = Arc::clone(&self.connection_semaphore);
-
- Box::pin(async move {
- let _permit = connection_semaphore
- .acquire_owned()
- .await
- .expect("semaphore shouldn't be closed yet");
-
- let response = s3_request(move || {
- let (s3, request_factory) = (s3.clone(), request_factory.clone());
- async move { s3.upload_part(request_factory()).await }
- })
- .await
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
-
- Ok((
- part_idx,
- UploadPart {
- content_id: response.e_tag.unwrap(),
- },
- ))
- })
- }
-
- fn complete(
- &self,
- completed_parts: Vec<Option<UploadPart>>,
- ) -> BoxFuture<'static, Result<(), io::Error>> {
- let parts =
- completed_parts
- .into_iter()
- .enumerate()
- .map(|(part_number, maybe_part)| match maybe_part {
- Some(part) => {
- Ok(rusoto_s3::CompletedPart {
- e_tag: Some(part.content_id),
- part_number: Some((part_number + 1).try_into().map_err(
- |err| io::Error::new(io::ErrorKind::Other, err),
- )?),
- })
- }
- None => Err(io::Error::new(
- io::ErrorKind::Other,
- format!("Missing information for upload part {:?}", part_number),
- )),
- });
-
- // Get values to move into future; we don't want a reference to Self
- let bucket = self.bucket.clone();
- let key = self.key.clone();
- let upload_id = self.upload_id.clone();
-
- let request_factory = move || -> Result<_, io::Error> {
- Ok(rusoto_s3::CompleteMultipartUploadRequest {
- bucket,
- key,
- upload_id,
- multipart_upload: Some(rusoto_s3::CompletedMultipartUpload {
- parts: Some(parts.collect::<Result<_, io::Error>>()?),
- }),
- ..Default::default()
- })
- };
-
- let s3 = self.client_unrestricted.clone();
- let connection_semaphore = Arc::clone(&self.connection_semaphore);
-
- Box::pin(async move {
- let _permit = connection_semaphore
- .acquire_owned()
- .await
- .expect("semaphore shouldn't be closed yet");
-
- s3_request(move || {
- let (s3, request_factory) = (s3.clone(), request_factory.clone());
-
- async move { s3.complete_multipart_upload(request_factory()?).await }
- })
- .await
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
-
- Ok(())
- })
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use crate::{
- tests::{
- get_nonexistent_object, list_uses_directories_correctly, list_with_delimiter,
- put_get_delete_list, rename_and_copy, stream_get,
- },
- Error as ObjectStoreError,
- };
- use bytes::Bytes;
- use std::env;
-
- const NON_EXISTENT_NAME: &str = "nonexistentname";
-
- // Helper macro to skip tests if TEST_INTEGRATION and the AWS
- // environment variables are not set. Returns a configured
- // AmazonS3Builder
- macro_rules! maybe_skip_integration {
- () => {{
- dotenv::dotenv().ok();
-
- let required_vars = [
- "AWS_DEFAULT_REGION",
- "OBJECT_STORE_BUCKET",
- "AWS_ACCESS_KEY_ID",
- "AWS_SECRET_ACCESS_KEY",
- ];
- let unset_vars: Vec<_> = required_vars
- .iter()
- .filter_map(|&name| match env::var(name) {
- Ok(_) => None,
- Err(_) => Some(name),
- })
- .collect();
- let unset_var_names = unset_vars.join(", ");
-
- let force = env::var("TEST_INTEGRATION");
-
- if force.is_ok() && !unset_var_names.is_empty() {
- panic!(
- "TEST_INTEGRATION is set, \
- but variable(s) {} need to be set",
- unset_var_names
- );
- } else if force.is_err() {
- eprintln!(
- "skipping AWS integration test - set {}TEST_INTEGRATION to run",
- if unset_var_names.is_empty() {
- String::new()
- } else {
- format!("{} and ", unset_var_names)
- }
- );
- return;
- } else {
- let config = AmazonS3Builder::new()
- .with_access_key_id(
- env::var("AWS_ACCESS_KEY_ID")
- .expect("already checked AWS_ACCESS_KEY_ID"),
- )
- .with_secret_access_key(
- env::var("AWS_SECRET_ACCESS_KEY")
- .expect("already checked AWS_SECRET_ACCESS_KEY"),
- )
- .with_region(
- env::var("AWS_DEFAULT_REGION")
- .expect("already checked AWS_DEFAULT_REGION"),
- )
- .with_bucket_name(
- env::var("OBJECT_STORE_BUCKET")
- .expect("already checked OBJECT_STORE_BUCKET"),
- )
- .with_allow_http(true);
-
- let config = if let Some(endpoint) = env::var("AWS_ENDPOINT").ok() {
- config.with_endpoint(endpoint)
- } else {
- config
- };
-
- let config = if let Some(token) = env::var("AWS_SESSION_TOKEN").ok() {
- config.with_token(token)
- } else {
- config
- };
-
- config
- }
- }};
- }
-
- #[tokio::test]
- async fn s3_test() {
- let config = maybe_skip_integration!();
- let integration = config.build().unwrap();
-
- put_get_delete_list(&integration).await;
- list_uses_directories_correctly(&integration).await;
- list_with_delimiter(&integration).await;
- rename_and_copy(&integration).await;
- stream_get(&integration).await;
- }
-
- #[tokio::test]
- async fn s3_test_get_nonexistent_location() {
- let config = maybe_skip_integration!();
- let integration = config.build().unwrap();
-
- let location = Path::from_iter([NON_EXISTENT_NAME]);
-
- let err = get_nonexistent_object(&integration, Some(location))
- .await
- .unwrap_err();
- if let ObjectStoreError::NotFound { path, source } = err {
- let source_variant = source.downcast_ref::<rusoto_core::RusotoError<_>>();
- assert!(
- matches!(
- source_variant,
- Some(rusoto_core::RusotoError::Service(
- rusoto_s3::GetObjectError::NoSuchKey(_)
- )),
- ),
- "got: {:?}",
- source_variant
- );
- assert_eq!(path, NON_EXISTENT_NAME);
- } else {
- panic!("unexpected error type: {:?}", err);
- }
- }
-
- #[tokio::test]
- async fn s3_test_get_nonexistent_bucket() {
- let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME);
- let integration = config.build().unwrap();
-
- let location = Path::from_iter([NON_EXISTENT_NAME]);
-
- let err = integration.get(&location).await.unwrap_err().to_string();
- assert!(
- err.contains("The specified bucket does not exist"),
- "{}",
- err
- )
- }
-
- #[tokio::test]
- async fn s3_test_put_nonexistent_bucket() {
- let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME);
-
- let integration = config.build().unwrap();
-
- let location = Path::from_iter([NON_EXISTENT_NAME]);
- let data = Bytes::from("arbitrary data");
-
- let err = integration
- .put(&location, data)
- .await
- .unwrap_err()
- .to_string();
-
- assert!(
- err.contains("The specified bucket does not exist")
- && err.contains("Unable to PUT data"),
- "{}",
- err
- )
- }
-
- #[tokio::test]
- async fn s3_test_delete_nonexistent_location() {
- let config = maybe_skip_integration!();
- let integration = config.build().unwrap();
-
- let location = Path::from_iter([NON_EXISTENT_NAME]);
-
- integration.delete(&location).await.unwrap();
- }
-
- #[tokio::test]
- async fn s3_test_delete_nonexistent_bucket() {
- let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME);
- let integration = config.build().unwrap();
-
- let location = Path::from_iter([NON_EXISTENT_NAME]);
-
- let err = integration.delete(&location).await.unwrap_err().to_string();
- assert!(
- err.contains("The specified bucket does not exist")
- && err.contains("Unable to DELETE data"),
- "{}",
- err
- )
- }
-}
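Note the retry behaviour deleted here: s3_request retried 5xx server errors up to MAX_NUM_RETRIES (3) times with exponential backoff, and never retried 4xx client errors. The replacement client delegates this to the shared send_retry from client/retry.rs, configured through RetryConfig (see the send_retry(&self.config.retry_config) calls in aws/client.rs below). For reference, the schedule the removed code used:

    use std::time::Duration;

    // Backoff used by the deleted s3_request: attempts 1, 2 and 3 wait
    // 100ms, 200ms and 400ms respectively before retrying.
    fn backoff(attempt: u32) -> Duration {
        Duration::from_millis(2u64.pow(attempt) * 50)
    }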
diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs
new file mode 100644
index 000000000..36ba9ad12
--- /dev/null
+++ b/object_store/src/aws/client.rs
@@ -0,0 +1,483 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::aws::credential::{AwsCredential, CredentialExt, CredentialProvider};
+use crate::client::pagination::stream_paginated;
+use crate::client::retry::RetryExt;
+use crate::multipart::UploadPart;
+use crate::path::DELIMITER;
+use crate::util::{format_http_range, format_prefix};
+use crate::{
+ BoxStream, ListResult, MultipartId, ObjectMeta, Path, Result, RetryConfig, StreamExt,
+};
+use bytes::{Buf, Bytes};
+use chrono::{DateTime, Utc};
+use percent_encoding::{utf8_percent_encode, AsciiSet, PercentEncode, NON_ALPHANUMERIC};
+use reqwest::{Client as ReqwestClient, Method, Response, StatusCode};
+use serde::{Deserialize, Serialize};
+use snafu::{ResultExt, Snafu};
+use std::ops::Range;
+use std::sync::Arc;
+
+// http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
+//
+// Do not URI-encode any of the unreserved characters that RFC 3986 defines:
+// A-Z, a-z, 0-9, hyphen ( - ), underscore ( _ ), period ( . ), and tilde ( ~ ).
+const STRICT_ENCODE_SET: AsciiSet = NON_ALPHANUMERIC
+ .remove(b'-')
+ .remove(b'.')
+ .remove(b'_')
+ .remove(b'~');
+
+/// This set is used to maintain the URI path encoding
+const STRICT_PATH_ENCODE_SET: AsciiSet = STRICT_ENCODE_SET.remove(b'/');
+
+/// A specialized `Error` for object store-related errors
+#[derive(Debug, Snafu)]
+#[allow(missing_docs)]
+pub(crate) enum Error {
+ #[snafu(display("Error performing get request {}: {}", path, source))]
+ GetRequest {
+ source: reqwest::Error,
+ path: String,
+ },
+
+ #[snafu(display("Error performing put request {}: {}", path, source))]
+ PutRequest {
+ source: reqwest::Error,
+ path: String,
+ },
+
+ #[snafu(display("Error performing delete request {}: {}", path, source))]
+ DeleteRequest {
+ source: reqwest::Error,
+ path: String,
+ },
+
+ #[snafu(display("Error performing copy request {}: {}", path, source))]
+ CopyRequest {
+ source: reqwest::Error,
+ path: String,
+ },
+
+ #[snafu(display("Error performing list request: {}", source))]
+ ListRequest { source: reqwest::Error },
+
+ #[snafu(display("Error performing create multipart request: {}", source))]
+ CreateMultipartRequest { source: reqwest::Error },
+
+ #[snafu(display("Error performing complete multipart request: {}", source))]
+ CompleteMultipartRequest { source: reqwest::Error },
+
+ #[snafu(display("Got invalid list response: {}", source))]
+ InvalidListResponse { source: quick_xml::de::DeError },
+
+ #[snafu(display("Got invalid multipart response: {}", source))]
+ InvalidMultipartResponse { source: quick_xml::de::DeError },
+}
+
+impl From<Error> for crate::Error {
+ fn from(err: Error) -> Self {
+ match err {
+ Error::GetRequest { source, path }
+ | Error::DeleteRequest { source, path }
+ | Error::CopyRequest { source, path }
+ | Error::PutRequest { source, path }
+ if matches!(source.status(), Some(StatusCode::NOT_FOUND)) =>
+ {
+ Self::NotFound {
+ path,
+ source: Box::new(source),
+ }
+ }
+ _ => Self::Generic {
+ store: "S3",
+ source: Box::new(err),
+ },
+ }
+ }
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "PascalCase")]
+pub struct ListResponse {
+ #[serde(default)]
+ pub contents: Vec<ListContents>,
+ #[serde(default)]
+ pub common_prefixes: Vec<ListPrefix>,
+ #[serde(default)]
+ pub next_continuation_token: Option<String>,
+}
+
+impl TryFrom<ListResponse> for ListResult {
+ type Error = crate::Error;
+
+ fn try_from(value: ListResponse) -> Result<Self> {
+ let common_prefixes = value
+ .common_prefixes
+ .into_iter()
+ .map(|x| Ok(Path::parse(&x.prefix)?))
+ .collect::<Result<_>>()?;
+
+ let objects = value
+ .contents
+ .into_iter()
+ .map(TryFrom::try_from)
+ .collect::<Result<_>>()?;
+
+ Ok(Self {
+ common_prefixes,
+ objects,
+ })
+ }
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "PascalCase")]
+pub struct ListPrefix {
+ pub prefix: String,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "PascalCase")]
+pub struct ListContents {
+ pub key: String,
+ pub size: usize,
+ pub last_modified: DateTime<Utc>,
+}
+
+impl TryFrom<ListContents> for ObjectMeta {
+ type Error = crate::Error;
+
+ fn try_from(value: ListContents) -> Result<Self> {
+ Ok(Self {
+ location: Path::parse(value.key)?,
+ last_modified: value.last_modified,
+ size: value.size,
+ })
+ }
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "PascalCase")]
+struct InitiateMultipart {
+ upload_id: String,
+}
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "PascalCase", rename = "CompleteMultipartUpload")]
+struct CompleteMultipart {
+ part: Vec<MultipartPart>,
+}
+
+#[derive(Debug, Serialize)]
+struct MultipartPart {
+ #[serde(rename = "$unflatten=ETag")]
+ e_tag: String,
+ #[serde(rename = "$unflatten=PartNumber")]
+ part_number: usize,
+}
+
+#[derive(Debug)]
+pub struct S3Config {
+ pub region: String,
+ pub endpoint: String,
+ pub bucket: String,
+ pub credentials: CredentialProvider,
+ pub retry_config: RetryConfig,
+ pub allow_http: bool,
+}
+
+impl S3Config {
+ fn path_url(&self, path: &Path) -> String {
+ format!("{}/{}/{}", self.endpoint, self.bucket, encode_path(path))
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct S3Client {
+ config: S3Config,
+ client: ReqwestClient,
+}
+
+impl S3Client {
+ pub fn new(config: S3Config) -> Self {
+ let client = reqwest::ClientBuilder::new()
+ .https_only(!config.allow_http)
+ .build()
+ .unwrap();
+
+ Self { config, client }
+ }
+
+ /// Returns the config
+ pub fn config(&self) -> &S3Config {
+ &self.config
+ }
+
+ async fn get_credential(&self) -> Result<Arc<AwsCredential>> {
+ self.config.credentials.get_credential().await
+ }
+
+ /// Make an S3 GET request <https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetObject.html>
+ pub async fn get_request(
+ &self,
+ path: &Path,
+ range: Option<Range<usize>>,
+ head: bool,
+ ) -> Result<Response> {
+ use reqwest::header::RANGE;
+
+ let credential = self.get_credential().await?;
+ let url = self.config.path_url(path);
+ let method = match head {
+ true => Method::HEAD,
+ false => Method::GET,
+ };
+
+ let mut builder = self.client.request(method, url);
+
+ if let Some(range) = range {
+ builder = builder.header(RANGE, format_http_range(range));
+ }
+
+ let response = builder
+ .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3")
+ .send_retry(&self.config.retry_config)
+ .await
+ .context(GetRequestSnafu {
+ path: path.as_ref(),
+ })?
+ .error_for_status()
+ .context(GetRequestSnafu {
+ path: path.as_ref(),
+ })?;
+
+ Ok(response)
+ }
+
+ /// Make an S3 PUT request <https://docs.aws.amazon.com/AmazonS3/latest/API/API_PutObject.html>
+ pub async fn put_request<T: Serialize + ?Sized + Sync>(
+ &self,
+ path: &Path,
+ bytes: Option<Bytes>,
+ query: &T,
+ ) -> Result<Response> {
+ let credential = self.get_credential().await?;
+ let url = self.config.path_url(path);
+
+ let mut builder = self.client.request(Method::PUT, url);
+ if let Some(bytes) = bytes {
+ builder = builder.body(bytes)
+ }
+
+ let response = builder
+ .query(query)
+ .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3")
+ .send_retry(&self.config.retry_config)
+ .await
+ .context(PutRequestSnafu {
+ path: path.as_ref(),
+ })?
+ .error_for_status()
+ .context(PutRequestSnafu {
+ path: path.as_ref(),
+ })?;
+
+ Ok(response)
+ }
+
+ /// Make an S3 Delete request <https://docs.aws.amazon.com/AmazonS3/latest/API/API_DeleteObject.html>
+ pub async fn delete_request<T: Serialize + ?Sized + Sync>(
+ &self,
+ path: &Path,
+ query: &T,
+ ) -> Result<()> {
+ let credential = self.get_credential().await?;
+ let url = self.config.path_url(path);
+
+ self.client
+ .request(Method::DELETE, url)
+ .query(query)
+ .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3")
+ .send_retry(&self.config.retry_config)
+ .await
+ .context(DeleteRequestSnafu {
+ path: path.as_ref(),
+ })?
+ .error_for_status()
+ .context(DeleteRequestSnafu {
+ path: path.as_ref(),
+ })?;
+
+ Ok(())
+ }
+
+ /// Make an S3 Copy request <https://docs.aws.amazon.com/AmazonS3/latest/API/API_CopyObject.html>
+ pub async fn copy_request(&self, from: &Path, to: &Path) -> Result<()> {
+ let credential = self.get_credential().await?;
+ let url = self.config.path_url(to);
+ let source = format!("{}/{}", self.config.bucket, encode_path(from));
+
+ self.client
+ .request(Method::PUT, url)
+ .header("x-amz-copy-source", source)
+ .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3")
+ .send_retry(&self.config.retry_config)
+ .await
+ .context(CopyRequestSnafu {
+ path: from.as_ref(),
+ })?
+ .error_for_status()
+ .context(CopyRequestSnafu {
+ path: from.as_ref(),
+ })?;
+
+ Ok(())
+ }
+
+ /// Make an S3 List request <https://docs.aws.amazon.com/AmazonS3/latest/API/API_ListObjectsV2.html>
+ async fn list_request(
+ &self,
+ prefix: Option<&str>,
+ delimiter: bool,
+ token: Option<&str>,
+ ) -> Result<(ListResult, Option<String>)> {
+ let credential = self.get_credential().await?;
+ let url = format!("{}/{}", self.config.endpoint, self.config.bucket);
+
+ let mut query = Vec::with_capacity(4);
+
+ // Note: the order of these matters to ensure the generated URL is canonical
+ if let Some(token) = token {
+ query.push(("continuation-token", token))
+ }
+
+ if delimiter {
+ query.push(("delimiter", DELIMITER))
+ }
+
+ query.push(("list-type", "2"));
+
+ if let Some(prefix) = prefix {
+ query.push(("prefix", prefix))
+ }
+
+ let response = self
+ .client
+ .request(Method::GET, &url)
+ .query(&query)
+ .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3")
+ .send_retry(&self.config.retry_config)
+ .await
+ .context(ListRequestSnafu)?
+ .error_for_status()
+ .context(ListRequestSnafu)?
+ .bytes()
+ .await
+ .context(ListRequestSnafu)?;
+
+ let mut response: ListResponse = quick_xml::de::from_reader(response.reader())
+ .context(InvalidListResponseSnafu)?;
+ let token = response.next_continuation_token.take();
+
+ Ok((response.try_into()?, token))
+ }
+
+ /// Perform a list operation automatically handling pagination
+ pub fn list_paginated(
+ &self,
+ prefix: Option<&Path>,
+ delimiter: bool,
+ ) -> BoxStream<'_, Result<ListResult>> {
+ let prefix = format_prefix(prefix);
+ stream_paginated(prefix, move |prefix, token| async move {
+ let (r, next_token) = self
+ .list_request(prefix.as_deref(), delimiter, token.as_deref())
+ .await?;
+ Ok((r, prefix, next_token))
+ })
+ .boxed()
+ }
+
+ pub async fn create_multipart(&self, location: &Path) -> Result<MultipartId> {
+ let credential = self.get_credential().await?;
+ let url = format!(
+ "{}/{}/{}?uploads",
+ self.config.endpoint,
+ self.config.bucket,
+ encode_path(location)
+ );
+
+ let response = self
+ .client
+ .request(Method::POST, url)
+ .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3")
+ .send_retry(&self.config.retry_config)
+ .await
+ .context(CreateMultipartRequestSnafu)?
+ .error_for_status()
+ .context(CreateMultipartRequestSnafu)?
+ .bytes()
+ .await
+ .context(CreateMultipartRequestSnafu)?;
+
+ let response: InitiateMultipart = quick_xml::de::from_reader(response.reader())
+ .context(InvalidMultipartResponseSnafu)?;
+
+ Ok(response.upload_id)
+ }
+
+ pub async fn complete_multipart(
+ &self,
+ location: &Path,
+ upload_id: &str,
+ parts: Vec<UploadPart>,
+ ) -> Result<()> {
+ let parts = parts
+ .into_iter()
+ .enumerate()
+ .map(|(part_idx, part)| MultipartPart {
+ e_tag: part.content_id,
+ part_number: part_idx + 1,
+ })
+ .collect();
+
+ let request = CompleteMultipart { part: parts };
+ let body = quick_xml::se::to_string(&request).unwrap();
+
+ let credential = self.get_credential().await?;
+ let url = self.config.path_url(location);
+
+ self.client
+ .request(Method::POST, url)
+ .query(&[("uploadId", upload_id)])
+ .body(body)
+ .with_aws_sigv4(credential.as_ref(), &self.config.region, "s3")
+ .send_retry(&self.config.retry_config)
+ .await
+ .context(CompleteMultipartRequestSnafu)?
+ .error_for_status()
+ .context(CompleteMultipartRequestSnafu)?;
+
+ Ok(())
+ }
+}
+
+fn encode_path(path: &Path) -> PercentEncode<'_> {
+ utf8_percent_encode(path.as_ref(), &STRICT_PATH_ENCODE_SET)
+}
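list_paginated above pushes continuation-token handling into the shared stream_paginated helper from client/pagination.rs (added by this commit but not shown in this message). The underlying pattern is the same unfold-driven state machine the deleted aws.rs spelled out inline; a minimal sketch, with fetch_page standing in as a hypothetical request that returns a page of items plus an optional next token:

    use futures::stream;
    use futures::Stream;

    enum ListState {
        Start,
        HasMore(String),
        Done,
    }

    // Hypothetical stand-in for one paginated request.
    async fn fetch_page(_token: Option<&str>) -> (Vec<String>, Option<String>) {
        (vec![], None)
    }

    fn pages() -> impl Stream<Item = Vec<String>> {
        stream::unfold(ListState::Start, |state| async move {
            let token = match state {
                ListState::Done => return None,
                ListState::HasMore(t) => Some(t),
                ListState::Start => None,
            };
            let (items, next) = fetch_page(token.as_deref()).await;
            let next_state = match next {
                Some(t) => ListState::HasMore(t),
                None => ListState::Done,
            };
            Some((items, next_state))
        })
    }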
diff --git a/object_store/src/aws/credential.rs b/object_store/src/aws/credential.rs
new file mode 100644
index 000000000..b75005975
--- /dev/null
+++ b/object_store/src/aws/credential.rs
@@ -0,0 +1,590 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::client::retry::RetryExt;
+use crate::client::token::{TemporaryToken, TokenCache};
+use crate::{Result, RetryConfig};
+use bytes::Buf;
+use chrono::{DateTime, Utc};
+use futures::TryFutureExt;
+use reqwest::header::{HeaderMap, HeaderValue};
+use reqwest::{Client, Method, Request, RequestBuilder};
+use serde::Deserialize;
+use std::collections::BTreeMap;
+use std::sync::Arc;
+use std::time::Instant;
+
+type StdError = Box<dyn std::error::Error + Send + Sync>;
+
+/// SHA256 hash of empty string
+static EMPTY_SHA256_HASH: &str =
+ "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
+
+#[derive(Debug)]
+pub struct AwsCredential {
+ pub key_id: String,
+ pub secret_key: String,
+ pub token: Option<String>,
+}
+
+impl AwsCredential {
+ /// Signs a string
+ ///
+ /// <https://docs.aws.amazon.com/general/latest/gr/sigv4-calculate-signature.html>
+ fn sign(
+ &self,
+ to_sign: &str,
+ date: DateTime<Utc>,
+ region: &str,
+ service: &str,
+ ) -> String {
+ let date_string = date.format("%Y%m%d").to_string();
+ let date_hmac = hmac_sha256(format!("AWS4{}", self.secret_key), date_string);
+ let region_hmac = hmac_sha256(date_hmac, region);
+ let service_hmac = hmac_sha256(region_hmac, service);
+ let signing_hmac = hmac_sha256(service_hmac, b"aws4_request");
+ hex_encode(hmac_sha256(signing_hmac, to_sign).as_ref())
+ }
+}
+
+struct RequestSigner<'a> {
+ date: DateTime<Utc>,
+ credential: &'a AwsCredential,
+ service: &'a str,
+ region: &'a str,
+}
+
+const DATE_HEADER: &str = "x-amz-date";
+const HASH_HEADER: &str = "x-amz-content-sha256";
+const TOKEN_HEADER: &str = "x-amz-security-token";
+const AUTH_HEADER: &str = "authorization";
+
+const ALL_HEADERS: &[&str; 4] = &[DATE_HEADER, HASH_HEADER, TOKEN_HEADER, AUTH_HEADER];
+
+impl<'a> RequestSigner<'a> {
+ fn sign(&self, request: &mut Request) {
+ if let Some(ref token) = self.credential.token {
+ let token_val = HeaderValue::from_str(token).unwrap();
+ request.headers_mut().insert(TOKEN_HEADER, token_val);
+ }
+
+ let host_val = HeaderValue::from_str(
+ &request.url()[url::Position::BeforeHost..url::Position::AfterPort],
+ )
+ .unwrap();
+ request.headers_mut().insert("host", host_val);
+
+ let date_str = self.date.format("%Y%m%dT%H%M%SZ").to_string();
+ let date_val = HeaderValue::from_str(&date_str).unwrap();
+ request.headers_mut().insert(DATE_HEADER, date_val);
+
+ let digest = match request.body() {
+ None => EMPTY_SHA256_HASH.to_string(),
+ Some(body) => hex_digest(body.as_bytes().unwrap()),
+ };
+
+ let header_digest = HeaderValue::from_str(&digest).unwrap();
+ request.headers_mut().insert(HASH_HEADER, header_digest);
+
+ let (signed_headers, canonical_headers) = canonicalize_headers(request.headers());
+
+ // https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html
+ let canonical_request = format!(
+ "{}\n{}\n{}\n{}\n{}\n{}",
+ request.method().as_str(),
+ request.url().path(), // S3 doesn't percent encode this like other services
+ request.url().query().unwrap_or(""), // This assumes the query pairs are in order
+ canonical_headers,
+ signed_headers,
+ digest
+ );
+
+ let hashed_canonical_request = hex_digest(canonical_request.as_bytes());
+ let scope = format!(
+ "{}/{}/{}/aws4_request",
+ self.date.format("%Y%m%d"),
+ self.region,
+ self.service
+ );
+
+ let string_to_sign = format!(
+ "AWS4-HMAC-SHA256\n{}\n{}\n{}",
+ self.date.format("%Y%m%dT%H%M%SZ"),
+ scope,
+ hashed_canonical_request
+ );
+
+ // sign the string
+ let signature =
+ self.credential
+ .sign(&string_to_sign, self.date, self.region, self.service);
+
+ // build the actual auth header
+ let authorization = format!(
+ "AWS4-HMAC-SHA256 Credential={}/{}, SignedHeaders={}, Signature={}",
+ self.credential.key_id, scope, signed_headers, signature
+ );
+
+ let authorization_val = HeaderValue::from_str(&authorization).unwrap();
+ request.headers_mut().insert(AUTH_HEADER, authorization_val);
+ }
+}
+
+pub trait CredentialExt {
+ /// Sign a request <https://docs.aws.amazon.com/general/latest/gr/sigv4_signing.html>
+ fn with_aws_sigv4(
+ self,
+ credential: &AwsCredential,
+ region: &str,
+ service: &str,
+ ) -> Self;
+}
+
+impl CredentialExt for RequestBuilder {
+ fn with_aws_sigv4(
+ mut self,
+ credential: &AwsCredential,
+ region: &str,
+ service: &str,
+ ) -> Self {
+ // Hack around lack of access to underlying request
+ // https://github.com/seanmonstar/reqwest/issues/1212
+ let mut request = self
+ .try_clone()
+ .expect("not stream")
+ .build()
+ .expect("request valid");
+
+ let date = Utc::now();
+ let signer = RequestSigner {
+ date,
+ credential,
+ service,
+ region,
+ };
+
+ signer.sign(&mut request);
+
+ for header in ALL_HEADERS {
+ if let Some(val) = request.headers_mut().remove(*header) {
+ self = self.header(*header, val)
+ }
+ }
+ self
+ }
+}
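
Illustration (a minimal editor's sketch, not part of the commit): the extension
trait above lets any non-streaming reqwest RequestBuilder be signed before
sending; the credential values are the AWS documentation examples:

    async fn signed_get(client: &reqwest::Client) -> reqwest::Result<reqwest::Response> {
        let credential = AwsCredential {
            key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
            secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
            token: None,
        };
        client
            .request(reqwest::Method::GET, "https://s3.us-east-1.amazonaws.com/")
            .with_aws_sigv4(&credential, "us-east-1", "s3")
            .send()
            .await
    }

Streaming bodies would panic here, as the builder is cloned internally to work
around the reqwest limitation noted above.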
+
+fn hmac_sha256(secret: impl AsRef<[u8]>, bytes: impl AsRef<[u8]>) -> ring::hmac::Tag {
+ let key = ring::hmac::Key::new(ring::hmac::HMAC_SHA256, secret.as_ref());
+ ring::hmac::sign(&key, bytes.as_ref())
+}
+
+/// Computes the SHA256 digest of `body` returned as a hex encoded string
+fn hex_digest(bytes: &[u8]) -> String {
+ let digest = ring::digest::digest(&ring::digest::SHA256, bytes);
+ hex_encode(digest.as_ref())
+}
+
+/// Returns `bytes` as a lower-case hex encoded string
+fn hex_encode(bytes: &[u8]) -> String {
+ use std::fmt::Write;
+ let mut out = String::with_capacity(bytes.len() * 2);
+ for byte in bytes {
+ // String writing is infallible
+ let _ = write!(out, "{:02x}", byte);
+ }
+ out
+}
+
+/// Canonicalizes headers into the AWS Canonical Form.
+///
+/// <https://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html>
+fn canonicalize_headers(header_map: &HeaderMap) -> (String, String) {
+ let mut headers = BTreeMap::<&str, Vec<&str>>::new();
+ let mut value_count = 0;
+ let mut value_bytes = 0;
+ let mut key_bytes = 0;
+
+ for (key, value) in header_map {
+ let key = key.as_str();
+ if ["authorization", "content-length", "user-agent"].contains(&key) {
+ continue;
+ }
+
+ let value = std::str::from_utf8(value.as_bytes()).unwrap();
+ key_bytes += key.len();
+ value_bytes += value.len();
+ value_count += 1;
+ headers.entry(key).or_default().push(value);
+ }
+
+ let mut signed_headers = String::with_capacity(key_bytes + headers.len());
+ let mut canonical_headers =
+ String::with_capacity(key_bytes + value_bytes + headers.len() + value_count);
+
+ for (header_idx, (name, values)) in headers.into_iter().enumerate() {
+ if header_idx != 0 {
+ signed_headers.push(';');
+ }
+
+ signed_headers.push_str(name);
+ canonical_headers.push_str(name);
+ canonical_headers.push(':');
+ for (value_idx, value) in values.into_iter().enumerate() {
+ if value_idx != 0 {
+ canonical_headers.push(',');
+ }
+ canonical_headers.push_str(value.trim());
+ }
+ canonical_headers.push('\n');
+ }
+
+ (signed_headers, canonical_headers)
+}
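
Worked example (editor's sketch, not part of the commit): two plain headers
canonicalize as follows; the BTreeMap sorts names, and every canonical line
ends in a newline:

    let mut headers = HeaderMap::new();
    headers.insert("host", HeaderValue::from_static("example.com"));
    headers.insert("x-amz-date", HeaderValue::from_static("20220806T180134Z"));
    let (signed, canonical) = canonicalize_headers(&headers);
    assert_eq!(signed, "host;x-amz-date");
    assert_eq!(canonical, "host:example.com\nx-amz-date:20220806T180134Z\n");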
+
+/// Provides credentials for use when signing requests
+#[derive(Debug)]
+pub enum CredentialProvider {
+ Static(StaticCredentialProvider),
+ Instance(InstanceCredentialProvider),
+ WebIdentity(WebIdentityProvider),
+}
+
+impl CredentialProvider {
+ pub async fn get_credential(&self) -> Result<Arc<AwsCredential>> {
+ match self {
+ Self::Static(s) => Ok(Arc::clone(&s.credential)),
+ Self::Instance(c) => c.get_credential().await,
+ Self::WebIdentity(c) => c.get_credential().await,
+ }
+ }
+}
+
+/// A static set of credentials
+#[derive(Debug)]
+pub struct StaticCredentialProvider {
+ pub credential: Arc<AwsCredential>,
+}
+
+/// Credentials sourced from the instance metadata service
+///
+/// <https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html>
+#[derive(Debug)]
+pub struct InstanceCredentialProvider {
+ pub cache: TokenCache<Arc<AwsCredential>>,
+ pub client: Client,
+ pub retry_config: RetryConfig,
+}
+
+impl InstanceCredentialProvider {
+ async fn get_credential(&self) -> Result<Arc<AwsCredential>> {
+ self.cache
+ .get_or_insert_with(|| {
+ const METADATA_ENDPOINT: &str = "http://169.254.169.254";
+ instance_creds(&self.client, &self.retry_config, METADATA_ENDPOINT)
+ .map_err(|source| crate::Error::Generic {
+ store: "S3",
+ source,
+ })
+ })
+ .await
+ }
+}
+
+/// Credentials sourced using AssumeRoleWithWebIdentity
+///
+/// <https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts-technical-overview.html>
+#[derive(Debug)]
+pub struct WebIdentityProvider {
+ pub cache: TokenCache<Arc<AwsCredential>>,
+ pub token: String,
+ pub role_arn: String,
+ pub session_name: String,
+ pub endpoint: String,
+ pub client: Client,
+ pub retry_config: RetryConfig,
+}
+
+impl WebIdentityProvider {
+ async fn get_credential(&self) -> Result<Arc<AwsCredential>> {
+ self.cache
+ .get_or_insert_with(|| {
+ web_identity(
+ &self.client,
+ &self.retry_config,
+ &self.token,
+ &self.role_arn,
+ &self.session_name,
+ &self.endpoint,
+ )
+ .map_err(|source| crate::Error::Generic {
+ store: "S3",
+ source,
+ })
+ })
+ .await
+ }
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "PascalCase")]
+struct InstanceCredentials {
+ access_key_id: String,
+ secret_access_key: String,
+ token: String,
+ expiration: DateTime<Utc>,
+}
+
+impl From<InstanceCredentials> for AwsCredential {
+ fn from(s: InstanceCredentials) -> Self {
+ Self {
+ key_id: s.access_key_id,
+ secret_key: s.secret_access_key,
+ token: Some(s.token),
+ }
+ }
+}
+
+/// <https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/iam-roles-for-amazon-ec2.html#instance-metadata-security-credentials>
+async fn instance_creds(
+ client: &Client,
+ retry_config: &RetryConfig,
+ endpoint: &str,
+) -> Result<TemporaryToken<Arc<AwsCredential>>, StdError> {
+ const CREDENTIALS_PATH: &str = "latest/meta-data/iam/security-credentials";
+ const AWS_EC2_METADATA_TOKEN_HEADER: &str = "X-aws-ec2-metadata-token";
+
+ let token_url = format!("{}/latest/api/token", endpoint);
+ let token = client
+ .request(Method::PUT, token_url)
+ .header("X-aws-ec2-metadata-token-ttl-seconds", "600") // 10 minute TTL
+ .send_retry(retry_config)
+ .await?
+ .text()
+ .await?;
+
+ let role_url = format!("{}/{}/", endpoint, CREDENTIALS_PATH);
+ let role = client
+ .request(Method::GET, role_url)
+ .header(AWS_EC2_METADATA_TOKEN_HEADER, &token)
+ .send_retry(retry_config)
+ .await?
+ .text()
+ .await?;
+
+ let creds_url = format!("{}/{}/{}", endpoint, CREDENTIALS_PATH, role);
+ let creds: InstanceCredentials = client
+ .request(Method::GET, creds_url)
+ .header(AWS_EC2_METADATA_TOKEN_HEADER, &token)
+ .send_retry(retry_config)
+ .await?
+ .json()
+ .await?;
+
+ let now = Utc::now();
+ let ttl = (creds.expiration - now).to_std().unwrap_or_default();
+ Ok(TemporaryToken {
+ token: Arc::new(creds.into()),
+ expiry: Instant::now() + ttl,
+ })
+}
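
For reference, the IMDSv2 exchange implemented above reduces to three HTTP
calls (editor's illustration; {endpoint} is the metadata service and {role}
the name returned by the second call), with both GETs carrying the session
token in the X-aws-ec2-metadata-token header:

    PUT {endpoint}/latest/api/token                                  -> session token
    GET {endpoint}/latest/meta-data/iam/security-credentials/        -> role name
    GET {endpoint}/latest/meta-data/iam/security-credentials/{role}  -> JSON credentials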
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "PascalCase")]
+struct AssumeRoleResponse {
+ assume_role_with_web_identity_result: AssumeRoleResult,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "PascalCase")]
+struct AssumeRoleResult {
+ credentials: AssumeRoleCredentials,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(rename_all = "PascalCase")]
+struct AssumeRoleCredentials {
+ session_token: String,
+ secret_access_key: String,
+ access_key_id: String,
+ expiration: DateTime<Utc>,
+}
+
+impl From<AssumeRoleCredentials> for AwsCredential {
+ fn from(s: AssumeRoleCredentials) -> Self {
+ Self {
+ key_id: s.access_key_id,
+ secret_key: s.secret_access_key,
+ token: Some(s.session_token),
+ }
+ }
+}
+
+/// <https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts-technical-overview.html>
+async fn web_identity(
+ client: &Client,
+ retry_config: &RetryConfig,
+ token: &str,
+ role_arn: &str,
+ session_name: &str,
+ endpoint: &str,
+) -> Result<TemporaryToken<Arc<AwsCredential>>, StdError> {
+ let bytes = client
+ .request(Method::POST, endpoint)
+ .query(&[
+ ("Action", "AssumeRoleWithWebIdentity"),
+ ("DurationSeconds", "3600"),
+ ("RoleArn", role_arn),
+ ("RoleSessionName", session_name),
+ ("Version", "2011-06-15"),
+ ("WebIdentityToken", token),
+ ])
+ .send_retry(retry_config)
+ .await?
+ .bytes()
+ .await?;
+
+ let resp: AssumeRoleResponse = quick_xml::de::from_reader(bytes.reader())
+ .map_err(|e| format!("Invalid AssumeRoleWithWebIdentity response: {}", e))?;
+
+ let creds = resp.assume_role_with_web_identity_result.credentials;
+ let now = Utc::now();
+ let ttl = (creds.expiration - now).to_std().unwrap_or_default();
+
+ Ok(TemporaryToken {
+ token: Arc::new(creds.into()),
+ expiry: Instant::now() + ttl,
+ })
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use reqwest::{Client, Method};
+ use std::env;
+
+ // Test generated using https://docs.aws.amazon.com/general/latest/gr/sigv4-signed-request-examples.html
+ #[test]
+ fn test_sign() {
+ let client = Client::new();
+
+ // Test credentials from https://docs.aws.amazon.com/AmazonS3/latest/userguide/RESTAuthentication.html
+ let credential = AwsCredential {
+ key_id: "AKIAIOSFODNN7EXAMPLE".to_string(),
+ secret_key: "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY".to_string(),
+ token: None,
+ };
+
+ // method = 'GET'
+ // service = 'ec2'
+ // host = 'ec2.amazon.com'
+ // region = 'us-east-1'
+ // endpoint = 'https://ec2.amazon.com'
+ // request_parameters = ''
+ let date = DateTime::parse_from_rfc3339("2022-08-06T18:01:34Z")
+ .unwrap()
+ .with_timezone(&Utc);
+
+ let mut request = client
+ .request(Method::GET, "https://ec2.amazon.com/")
+ .build()
+ .unwrap();
+
+ let signer = RequestSigner {
+ date,
+ credential: &credential,
+ service: "ec2",
+ region: "us-east-1",
+ };
+
+ signer.sign(&mut request);
+ assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20220806/us-east-1/ec2/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=a3c787a7ed37f7fdfbfd2d7056a3d7c9d85e6d52a2bfbec73793c0be6e7862d4")
+ }
+
+ #[test]
+ fn test_sign_port() {
+ let client = Client::new();
+
+ let credential = AwsCredential {
+ key_id: "H20ABqCkLZID4rLe".to_string(),
+ secret_key: "jMqRDgxSsBqqznfmddGdu1TmmZOJQxdM".to_string(),
+ token: None,
+ };
+
+ let date = DateTime::parse_from_rfc3339("2022-08-09T13:05:25Z")
+ .unwrap()
+ .with_timezone(&Utc);
+
+ let mut request = client
+ .request(Method::GET, "http://localhost:9000/tsm-schemas")
+ .query(&[
+ ("delimiter", "/"),
+ ("encoding-type", "url"),
+ ("list-type", "2"),
+ ("prefix", ""),
+ ])
+ .build()
+ .unwrap();
+
+ let signer = RequestSigner {
+ date,
+ credential: &credential,
+ service: "s3",
+ region: "us-east-1",
+ };
+
+ signer.sign(&mut request);
+ assert_eq!(request.headers().get(AUTH_HEADER).unwrap(), "AWS4-HMAC-SHA256 Credential=H20ABqCkLZID4rLe/20220809/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=9ebf2f92872066c99ac94e573b4e1b80f4dbb8a32b1e8e23178318746e7d1b4d")
+ }
+
+ #[tokio::test]
+ async fn test_instance_metadata() {
+ if env::var("TEST_INTEGRATION").is_err() {
+ eprintln!("skipping AWS integration test");
+ }
+
+ // For example https://github.com/aws/amazon-ec2-metadata-mock
+ let endpoint = env::var("EC2_METADATA_ENDPOINT").unwrap();
+ let client = Client::new();
+ let retry_config = RetryConfig::default();
+
+ // Verify the endpoint only allows IMDSv2
+ let resp = client
+ .request(Method::GET, format!("{}/latest/meta-data/ami-id", endpoint))
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(
+ resp.status(),
+ reqwest::StatusCode::UNAUTHORIZED,
+ "Ensure metadata endpoint is set to only allow IMDSv2"
+ );
+
+ let creds = instance_creds(&client, &retry_config, &endpoint)
+ .await
+ .unwrap();
+
+ let id = &creds.token.key_id;
+ let secret = &creds.token.secret_key;
+ let token = creds.token.token.as_ref().unwrap();
+
+ assert!(!id.is_empty());
+ assert!(!secret.is_empty());
+ assert!(!token.is_empty())
+ }
+}
diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs
new file mode 100644
index 000000000..06d20ccc9
--- /dev/null
+++ b/object_store/src/aws/mod.rs
@@ -0,0 +1,646 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! An object store implementation for S3
+//!
+//! ## Multi-part uploads
+//!
+//! Multi-part uploads can be initiated with the [ObjectStore::put_multipart] method.
+//! Data passed to the writer is automatically buffered to meet the minimum size
+//! requirements for a part. Multiple parts are uploaded concurrently.
+//!
+//! If the writer fails for any reason, parts may remain uploaded to AWS, unused
+//! yet still incurring charges. Use the [ObjectStore::abort_multipart] method
+//! to abort the upload and drop those unneeded parts. In addition, consider
+//! implementing [automatic cleanup] of unused parts older than one week.
+//!
+//! [automatic cleanup]: https://aws.amazon.com/blogs/aws/s3-lifecycle-management-update-support-for-multipart-uploads-and-delete-markers/
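
As an illustration of the API described above (an editor's sketch, not part of
the commit; the path and payload are made up), the returned writer is driven
with tokio::io::AsyncWriteExt, and shutdown completes the upload; the `?`
conversions work because this change also adds From<crate::Error> for
std::io::Error:

    use tokio::io::AsyncWriteExt;

    async fn upload(store: &dyn ObjectStore) -> std::io::Result<()> {
        let path = Path::from("data/large.bin");
        let (_id, mut writer) = store.put_multipart(&path).await?;
        writer.write_all(&[0u8; 1024]).await?;
        writer.shutdown().await?; // flushes buffered parts and completes the upload
        Ok(())
    }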
+
+use async_trait::async_trait;
+use bytes::Bytes;
+use chrono::{DateTime, Utc};
+use futures::stream::BoxStream;
+use futures::TryStreamExt;
+use reqwest::Client;
+use snafu::{OptionExt, ResultExt, Snafu};
+use std::collections::BTreeSet;
+use std::ops::Range;
+use std::sync::Arc;
+use tokio::io::AsyncWrite;
+use tracing::info;
+
+use crate::aws::client::{S3Client, S3Config};
+use crate::aws::credential::{
+ AwsCredential, CredentialProvider, InstanceCredentialProvider,
+ StaticCredentialProvider, WebIdentityProvider,
+};
+use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart};
+use crate::{
+ GetResult, ListResult, MultipartId, ObjectMeta, ObjectStore, Path, Result,
+ RetryConfig, StreamExt,
+};
+
+mod client;
+mod credential;
+
+/// A specialized `Error` for object store-related errors
+#[derive(Debug, Snafu)]
+#[allow(missing_docs)]
+enum Error {
+ #[snafu(display("Last-Modified Header missing from response"))]
+ MissingLastModified,
+
+ #[snafu(display("Content-Length Header missing from response"))]
+ MissingContentLength,
+
+ #[snafu(display("Invalid last modified '{}': {}", last_modified, source))]
+ InvalidLastModified {
+ last_modified: String,
+ source: chrono::ParseError,
+ },
+
+ #[snafu(display("Invalid content length '{}': {}", content_length, source))]
+ InvalidContentLength {
+ content_length: String,
+ source: std::num::ParseIntError,
+ },
+
+ #[snafu(display("Missing region"))]
+ MissingRegion,
+
+ #[snafu(display("Missing bucket name"))]
+ MissingBucketName,
+
+ #[snafu(display("Missing AccessKeyId"))]
+ MissingAccessKeyId,
+
+ #[snafu(display("Missing SecretAccessKey"))]
+ MissingSecretAccessKey,
+
+ #[snafu(display("ETag Header missing from response"))]
+ MissingEtag,
+
+ #[snafu(display("Received header containing non-ASCII data"))]
+ BadHeader { source: reqwest::header::ToStrError },
+
+ #[snafu(display("Error reading token file: {}", source))]
+ ReadTokenFile { source: std::io::Error },
+}
+
+impl From<Error> for super::Error {
+ fn from(err: Error) -> Self {
+ Self::Generic {
+ store: "S3",
+ source: Box::new(err),
+ }
+ }
+}
+
+/// Interface for [Amazon S3](https://aws.amazon.com/s3/).
+#[derive(Debug)]
+pub struct AmazonS3 {
+ client: Arc<S3Client>,
+}
+
+impl std::fmt::Display for AmazonS3 {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ write!(f, "AmazonS3({})", self.client.config().bucket)
+ }
+}
+
+#[async_trait]
+impl ObjectStore for AmazonS3 {
+ async fn put(&self, location: &Path, bytes: Bytes) -> Result<()> {
+ self.client.put_request(location, Some(bytes), &()).await?;
+ Ok(())
+ }
+
+ async fn put_multipart(
+ &self,
+ location: &Path,
+ ) -> Result<(MultipartId, Box<dyn AsyncWrite + Unpin + Send>)> {
+ let id = self.client.create_multipart(location).await?;
+
+ let upload = S3MultiPartUpload {
+ location: location.clone(),
+ upload_id: id.clone(),
+ client: Arc::clone(&self.client),
+ };
+
+ Ok((id, Box::new(CloudMultiPartUpload::new(upload, 8))))
+ }
+
+ async fn abort_multipart(
+ &self,
+ location: &Path,
+ multipart_id: &MultipartId,
+ ) -> Result<()> {
+ self.client
+ .delete_request(location, &[("uploadId", multipart_id)])
+ .await
+ }
+
+ async fn get(&self, location: &Path) -> Result<GetResult> {
+ let response = self.client.get_request(location, None, false).await?;
+ let stream = response
+ .bytes_stream()
+ .map_err(|source| crate::Error::Generic {
+ store: "S3",
+ source: Box::new(source),
+ })
+ .boxed();
+
+ Ok(GetResult::Stream(stream))
+ }
+
+ async fn get_range(&self, location: &Path, range: Range<usize>) -> Result<Bytes> {
+ let bytes = self
+ .client
+ .get_request(location, Some(range), false)
+ .await?
+ .bytes()
+ .await
+ .map_err(|source| client::Error::GetRequest {
+ source,
+ path: location.to_string(),
+ })?;
+ Ok(bytes)
+ }
+
+ async fn head(&self, location: &Path) -> Result<ObjectMeta> {
+ use reqwest::header::{CONTENT_LENGTH, LAST_MODIFIED};
+
+ // Extract meta from headers
+ // https://docs.aws.amazon.com/AmazonS3/latest/API/API_HeadObject.html#API_HeadObject_ResponseSyntax
+ let response = self.client.get_request(location, None, true).await?;
+ let headers = response.headers();
+
+ let last_modified = headers
+ .get(LAST_MODIFIED)
+ .context(MissingLastModifiedSnafu)?;
+
+ let content_length = headers
+ .get(CONTENT_LENGTH)
+ .context(MissingContentLengthSnafu)?;
+
+ let last_modified = last_modified.to_str().context(BadHeaderSnafu)?;
+ let last_modified = DateTime::parse_from_rfc2822(last_modified)
+ .context(InvalidLastModifiedSnafu { last_modified })?
+ .with_timezone(&Utc);
+
+ let content_length = content_length.to_str().context(BadHeaderSnafu)?;
+ let content_length = content_length
+ .parse()
+ .context(InvalidContentLengthSnafu { content_length })?;
+ Ok(ObjectMeta {
+ location: location.clone(),
+ last_modified,
+ size: content_length,
+ })
+ }
+
+ async fn delete(&self, location: &Path) -> Result<()> {
+ self.client.delete_request(location, &()).await
+ }
+
+ async fn list(
+ &self,
+ prefix: Option<&Path>,
+ ) -> Result<BoxStream<'_, Result<ObjectMeta>>> {
+ let stream = self
+ .client
+ .list_paginated(prefix, false)
+ .map_ok(|r| futures::stream::iter(r.objects.into_iter().map(Ok)))
+ .try_flatten()
+ .boxed();
+
+ Ok(stream)
+ }
+
+ async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result<ListResult> {
+ let mut stream = self.client.list_paginated(prefix, true);
+
+ let mut common_prefixes = BTreeSet::new();
+ let mut objects = Vec::new();
+
+ while let Some(result) = stream.next().await {
+ let response = result?;
+ common_prefixes.extend(response.common_prefixes.into_iter());
+ objects.extend(response.objects.into_iter());
+ }
+
+ Ok(ListResult {
+ common_prefixes: common_prefixes.into_iter().collect(),
+ objects,
+ })
+ }
+
+ async fn copy(&self, from: &Path, to: &Path) -> Result<()> {
+ self.client.copy_request(from, to).await
+ }
+
+ async fn copy_if_not_exists(&self, _source: &Path, _dest: &Path) -> Result<()> {
+ // Will need dynamodb_lock
+ Err(crate::Error::NotImplemented)
+ }
+}
+
+struct S3MultiPartUpload {
+ location: Path,
+ upload_id: String,
+ client: Arc<S3Client>,
+}
+
+#[async_trait]
+impl CloudMultiPartUploadImpl for S3MultiPartUpload {
+ async fn put_multipart_part(
+ &self,
+ buf: Vec<u8>,
+ part_idx: usize,
+ ) -> Result<UploadPart, std::io::Error> {
+ use reqwest::header::ETAG;
+ let part = (part_idx + 1).to_string();
+
+ let response = self
+ .client
+ .put_request(
+ &self.location,
+ Some(buf.into()),
+ &[("partNumber", &part), ("uploadId", &self.upload_id)],
+ )
+ .await?;
+
+ let etag = response
+ .headers()
+ .get(ETAG)
+ .context(MissingEtagSnafu)
+ .map_err(crate::Error::from)?;
+
+ let etag = etag
+ .to_str()
+ .context(BadHeaderSnafu)
+ .map_err(crate::Error::from)?;
+
+ Ok(UploadPart {
+ content_id: etag.to_string(),
+ })
+ }
+
+ async fn complete(
+ &self,
+ completed_parts: Vec<UploadPart>,
+ ) -> Result<(), std::io::Error> {
+ self.client
+ .complete_multipart(&self.location, &self.upload_id, completed_parts)
+ .await?;
+ Ok(())
+ }
+}
+
+/// Configure a connection to Amazon S3 using the specified credentials in
+/// the specified Amazon region and bucket.
+///
+/// # Example
+/// ```
+/// # let REGION = "foo";
+/// # let BUCKET_NAME = "foo";
+/// # let ACCESS_KEY_ID = "foo";
+/// # let SECRET_KEY = "foo";
+/// # use object_store::aws::AmazonS3Builder;
+/// let s3 = AmazonS3Builder::new()
+/// .with_region(REGION)
+/// .with_bucket_name(BUCKET_NAME)
+/// .with_access_key_id(ACCESS_KEY_ID)
+/// .with_secret_access_key(SECRET_KEY)
+/// .build();
+/// ```
+#[derive(Debug, Default)]
+pub struct AmazonS3Builder {
+ access_key_id: Option<String>,
+ secret_access_key: Option<String>,
+ region: Option<String>,
+ bucket_name: Option<String>,
+ endpoint: Option<String>,
+ token: Option<String>,
+ retry_config: RetryConfig,
+ allow_http: bool,
+}
+
+impl AmazonS3Builder {
+ /// Create a new [`AmazonS3Builder`] with default values.
+ pub fn new() -> Self {
+ Default::default()
+ }
+
+ /// Set the AWS Access Key (required)
+ pub fn with_access_key_id(mut self, access_key_id: impl Into<String>) -> Self {
+ self.access_key_id = Some(access_key_id.into());
+ self
+ }
+
+ /// Set the AWS Secret Access Key (required)
+ pub fn with_secret_access_key(
+ mut self,
+ secret_access_key: impl Into<String>,
+ ) -> Self {
+ self.secret_access_key = Some(secret_access_key.into());
+ self
+ }
+
+ /// Set the region (e.g. `us-east-1`) (required)
+ pub fn with_region(mut self, region: impl Into<String>) -> Self {
+ self.region = Some(region.into());
+ self
+ }
+
+ /// Set the bucket name (required)
+ pub fn with_bucket_name(mut self, bucket_name: impl Into<String>) -> Self {
+ self.bucket_name = Some(bucket_name.into());
+ self
+ }
+
+ /// Sets the endpoint for communicating with AWS S3. Default value
+ /// is based on region.
+ ///
+ /// For example, this might be set to `"http://localhost:4566"`
+ /// for testing against a localstack instance.
+ pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
+ self.endpoint = Some(endpoint.into());
+ self
+ }
+
+ /// Set the token to use for requests (passed to underlying provider)
+ pub fn with_token(mut self, token: impl Into<String>) -> Self {
+ self.token = Some(token.into());
+ self
+ }
+
+ /// Sets what protocols are allowed. If `allow_http` is:
+ /// * false (default): Only HTTPS is allowed
+ /// * true: HTTP and HTTPS are allowed
+ pub fn with_allow_http(mut self, allow_http: bool) -> Self {
+ self.allow_http = allow_http;
+ self
+ }
+
+ /// Set the retry configuration
+ pub fn with_retry(mut self, retry_config: RetryConfig) -> Self {
+ self.retry_config = retry_config;
+ self
+ }
+
+ /// Create a [`AmazonS3`] instance from the provided values,
+ /// consuming `self`.
+ pub fn build(self) -> Result<AmazonS3> {
+ let bucket = self.bucket_name.context(MissingBucketNameSnafu)?;
+ let region = self.region.context(MissingRegionSnafu)?;
+
+ let credentials = match (self.access_key_id, self.secret_access_key, self.token) {
+ (Some(key_id), Some(secret_key), token) => {
+ info!("Using Static credential provider");
+ CredentialProvider::Static(StaticCredentialProvider {
+ credential: Arc::new(AwsCredential {
+ key_id,
+ secret_key,
+ token,
+ }),
+ })
+ }
+ (None, Some(_), _) => return Err(Error::MissingAccessKeyId.into()),
+ (Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()),
+ // TODO: Replace with `AmazonS3Builder::credentials_from_env`
+ _ => match (
+ std::env::var_os("AWS_WEB_IDENTITY_TOKEN_FILE"),
+ std::env::var("AWS_ROLE_ARN"),
+ ) {
+ (Some(token_file), Ok(role_arn)) => {
+ info!("Using WebIdentity credential provider");
+ let token = std::fs::read_to_string(token_file)
+ .context(ReadTokenFileSnafu)?;
+
+ let session_name = std::env::var("AWS_ROLE_SESSION_NAME")
+ .unwrap_or_else(|_| "WebIdentitySession".to_string());
+
+ let endpoint = format!("https://sts.{}.amazonaws.com", region);
+
+ // Disallow non-HTTPs requests
+ let client = Client::builder().https_only(true).build().unwrap();
+
+ CredentialProvider::WebIdentity(WebIdentityProvider {
+ cache: Default::default(),
+ token,
+ session_name,
+ role_arn,
+ endpoint,
+ client,
+ retry_config: self.retry_config.clone(),
+ })
+ }
+ _ => {
+ info!("Using Instance credential provider");
+
+ // The instance metadata endpoint is accessed over HTTP
+ let client = Client::builder().https_only(false).build().unwrap();
+
+ CredentialProvider::Instance(InstanceCredentialProvider {
+ cache: Default::default(),
+ client,
+ retry_config: self.retry_config.clone(),
+ })
+ }
+ },
+ };
+
+ let endpoint = self
+ .endpoint
+ .unwrap_or_else(|| format!("https://s3.{}.amazonaws.com", region));
+
+ let config = S3Config {
+ region,
+ endpoint,
+ bucket,
+ credentials,
+ retry_config: self.retry_config,
+ allow_http: self.allow_http,
+ };
+
+ let client = Arc::new(S3Client::new(config));
+
+ Ok(AmazonS3 { client })
+ }
+}
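
For illustration (editor's sketch, not part of the commit; the bucket and
region are placeholders), omitting the keys exercises the fallback chain
above: WebIdentity when AWS_WEB_IDENTITY_TOKEN_FILE and AWS_ROLE_ARN are set,
otherwise the instance metadata service:

    fn build_from_environment() -> Result<AmazonS3> {
        AmazonS3Builder::new()
            .with_region("us-east-1")
            .with_bucket_name("my-bucket")
            .build()
    }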
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::tests::{
+ get_nonexistent_object, list_uses_directories_correctly, list_with_delimiter,
+ put_get_delete_list, rename_and_copy, stream_get,
+ };
+ use bytes::Bytes;
+ use std::env;
+
+ const NON_EXISTENT_NAME: &str = "nonexistentname";
+
+ // Helper macro to skip tests if TEST_INTEGRATION and the AWS
+ // environment variables are not set. Returns a configured
+ // AmazonS3Builder
+ macro_rules! maybe_skip_integration {
+ () => {{
+ dotenv::dotenv().ok();
+
+ let required_vars = [
+ "AWS_DEFAULT_REGION",
+ "OBJECT_STORE_BUCKET",
+ "AWS_ACCESS_KEY_ID",
+ "AWS_SECRET_ACCESS_KEY",
+ ];
+ let unset_vars: Vec<_> = required_vars
+ .iter()
+ .filter_map(|&name| match env::var(name) {
+ Ok(_) => None,
+ Err(_) => Some(name),
+ })
+ .collect();
+ let unset_var_names = unset_vars.join(", ");
+
+ let force = env::var("TEST_INTEGRATION");
+
+ if force.is_ok() && !unset_var_names.is_empty() {
+ panic!(
+ "TEST_INTEGRATION is set, \
+ but variable(s) {} need to be set",
+ unset_var_names
+ );
+ } else if force.is_err() {
+ eprintln!(
+ "skipping AWS integration test - set {}TEST_INTEGRATION to run",
+ if unset_var_names.is_empty() {
+ String::new()
+ } else {
+ format!("{} and ", unset_var_names)
+ }
+ );
+ return;
+ } else {
+ let config = AmazonS3Builder::new()
+ .with_access_key_id(
+ env::var("AWS_ACCESS_KEY_ID")
+ .expect("already checked AWS_ACCESS_KEY_ID"),
+ )
+ .with_secret_access_key(
+ env::var("AWS_SECRET_ACCESS_KEY")
+ .expect("already checked AWS_SECRET_ACCESS_KEY"),
+ )
+ .with_region(
+ env::var("AWS_DEFAULT_REGION")
+ .expect("already checked AWS_DEFAULT_REGION"),
+ )
+ .with_bucket_name(
+ env::var("OBJECT_STORE_BUCKET")
+ .expect("already checked OBJECT_STORE_BUCKET"),
+ )
+ .with_allow_http(true);
+
+ let config = if let Ok(endpoint) = env::var("AWS_ENDPOINT") {
+ config.with_endpoint(endpoint)
+ } else {
+ config
+ };
+
+ let config = if let Ok(token) = env::var("AWS_SESSION_TOKEN") {
+ config.with_token(token)
+ } else {
+ config
+ };
+
+ config
+ }
+ }};
+ }
+
+ #[tokio::test]
+ async fn s3_test() {
+ let config = maybe_skip_integration!();
+ let integration = config.build().unwrap();
+
+ put_get_delete_list(&integration).await;
+ list_uses_directories_correctly(&integration).await;
+ list_with_delimiter(&integration).await;
+ rename_and_copy(&integration).await;
+ stream_get(&integration).await;
+ }
+
+ #[tokio::test]
+ async fn s3_test_get_nonexistent_location() {
+ let config = maybe_skip_integration!();
+ let integration = config.build().unwrap();
+
+ let location = Path::from_iter([NON_EXISTENT_NAME]);
+
+ let err = get_nonexistent_object(&integration, Some(location))
+ .await
+ .unwrap_err();
+ assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err);
+ }
+
+ #[tokio::test]
+ async fn s3_test_get_nonexistent_bucket() {
+ let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME);
+ let integration = config.build().unwrap();
+
+ let location = Path::from_iter([NON_EXISTENT_NAME]);
+
+ let err = integration.get(&location).await.unwrap_err();
+ assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err);
+ }
+
+ #[tokio::test]
+ async fn s3_test_put_nonexistent_bucket() {
+ let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME);
+
+ let integration = config.build().unwrap();
+
+ let location = Path::from_iter([NON_EXISTENT_NAME]);
+ let data = Bytes::from("arbitrary data");
+
+ let err = integration.put(&location, data).await.unwrap_err();
+ assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err);
+ }
+
+ #[tokio::test]
+ async fn s3_test_delete_nonexistent_location() {
+ let config = maybe_skip_integration!();
+ let integration = config.build().unwrap();
+
+ let location = Path::from_iter([NON_EXISTENT_NAME]);
+
+ integration.delete(&location).await.unwrap();
+ }
+
+ #[tokio::test]
+ async fn s3_test_delete_nonexistent_bucket() {
+ let config = maybe_skip_integration!().with_bucket_name(NON_EXISTENT_NAME);
+ let integration = config.build().unwrap();
+
+ let location = Path::from_iter([NON_EXISTENT_NAME]);
+
+ let err = integration.delete(&location).await.unwrap_err();
+ assert!(matches!(err, crate::Error::NotFound { .. }), "{}", err);
+ }
+}
diff --git a/object_store/src/azure.rs b/object_store/src/azure.rs
index 9987c0370..a9dbc53e2 100644
--- a/object_store/src/azure.rs
+++ b/object_store/src/azure.rs
@@ -49,7 +49,7 @@ use azure_storage_blobs::prelude::{
};
use bytes::Bytes;
use chrono::{TimeZone, Utc};
-use futures::{future::BoxFuture, stream::BoxStream, StreamExt, TryStreamExt};
+use futures::{stream::BoxStream, StreamExt, TryStreamExt};
use snafu::{ResultExt, Snafu};
use std::collections::BTreeSet;
use std::fmt::{Debug, Formatter};
@@ -765,70 +765,47 @@ impl AzureMultiPartUpload {
}
}
+#[async_trait]
impl CloudMultiPartUploadImpl for AzureMultiPartUpload {
- fn put_multipart_part(
+ async fn put_multipart_part(
&self,
buf: Vec<u8>,
part_idx: usize,
- ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>> {
- let client = Arc::clone(&self.container_client);
- let location = self.location.clone();
+ ) -> Result<UploadPart, io::Error> {
let block_id = self.get_block_id(part_idx);
- Box::pin(async move {
- client
- .blob_client(location.as_ref())
- .put_block(block_id.clone(), buf)
- .into_future()
- .await
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
+ self.container_client
+ .blob_client(self.location.as_ref())
+ .put_block(block_id.clone(), buf)
+ .into_future()
+ .await
+ .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
- Ok((
- part_idx,
- UploadPart {
- content_id: block_id,
- },
- ))
+ Ok(UploadPart {
+ content_id: block_id,
})
}
- fn complete(
- &self,
- completed_parts: Vec<Option<UploadPart>>,
- ) -> BoxFuture<'static, Result<(), io::Error>> {
- let parts =
- completed_parts
- .into_iter()
- .enumerate()
- .map(|(part_number, maybe_part)| match maybe_part {
- Some(part) => {
- Ok(azure_storage_blobs::blob::BlobBlockType::Uncommitted(
- azure_storage_blobs::prelude::BlockId::new(part.content_id),
- ))
- }
- None => Err(io::Error::new(
- io::ErrorKind::Other,
- format!("Missing information for upload part {:?}", part_number),
- )),
- });
-
- let client = Arc::clone(&self.container_client);
- let location = self.location.clone();
-
- Box::pin(async move {
- let block_list = azure_storage_blobs::blob::BlockList {
- blocks: parts.collect::<Result<_, io::Error>>()?,
- };
-
- client
- .blob_client(location.as_ref())
- .put_block_list(block_list)
- .into_future()
- .await
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
+ async fn complete(&self, completed_parts: Vec<UploadPart>) -> Result<(), io::Error> {
+ let blocks = completed_parts
+ .into_iter()
+ .map(|part| {
+ azure_storage_blobs::blob::BlobBlockType::Uncommitted(
+ azure_storage_blobs::prelude::BlockId::new(part.content_id),
+ )
+ })
+ .collect();
- Ok(())
- })
+ let block_list = azure_storage_blobs::blob::BlockList { blocks };
+
+ self.container_client
+ .blob_client(self.location.as_ref())
+ .put_block_list(block_list)
+ .into_future()
+ .await
+ .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
+
+ Ok(())
}
}
diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs
index 1166ebe7a..7241002a0 100644
--- a/object_store/src/client/mod.rs
+++ b/object_store/src/client/mod.rs
@@ -18,6 +18,8 @@
//! Generic utilities for reqwest-based ObjectStore implementations
pub mod backoff;
+#[cfg(feature = "gcp")]
pub mod oauth;
+pub mod pagination;
pub mod retry;
pub mod token;
diff --git a/object_store/src/client/pagination.rs b/object_store/src/client/pagination.rs
new file mode 100644
index 000000000..3ab17fe8b
--- /dev/null
+++ b/object_store/src/client/pagination.rs
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::Result;
+use futures::Stream;
+use std::future::Future;
+
+/// Takes a paginated operation `op` that when called with:
+///
+/// - A state `S`
+/// - An optional next token `Option<String>`
+///
+/// Returns
+///
+/// - A response value `T`
+/// - The next state `S`
+/// - The next continuation token `Option<String>`
+///
+/// And converts it into a `Stream<Result<T>>` which will first call `op(state, None)`, and yield
+/// the returned response `T`. If the returned continuation token was `None` the stream will then
+/// finish, otherwise it will continue to call `op(state, token)` with the values returned by the
+/// previous call to `op`, until a continuation token of `None` is returned
+///
+pub fn stream_paginated<F, Fut, S, T>(state: S, op: F) -> impl Stream<Item = Result<T>>
+where
+ F: Fn(S, Option<String>) -> Fut + Copy,
+ Fut: Future<Output = Result<(T, S, Option<String>)>>,
+{
+ enum PaginationState<T> {
+ Start(T),
+ HasMore(T, String),
+ Done,
+ }
+
+ futures::stream::unfold(PaginationState::Start(state), move |state| async move {
+ let (s, page_token) = match state {
+ PaginationState::Start(s) => (s, None),
+ PaginationState::HasMore(s, page_token) => (s, Some(page_token)),
+ PaginationState::Done => {
+ return None;
+ }
+ };
+
+ let (resp, s, continuation) = match op(s, page_token).await {
+ Ok(resp) => resp,
+ Err(e) => return Some((Err(e), PaginationState::Done)),
+ };
+
+ let next_state = match continuation {
+ Some(token) => PaginationState::HasMore(s, token),
+ None => PaginationState::Done,
+ };
+
+ Some((Ok(resp), next_state))
+ })
+}
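
Usage sketch (editor's illustration, not part of the commit): the state can be
any value threaded through the calls; here it is a vector of in-memory pages
and the continuation token encodes the next index:

    use futures::TryStreamExt;

    async fn collect_all() -> Result<Vec<u32>> {
        let pages: Vec<Vec<u32>> = vec![vec![1, 2], vec![3], vec![4, 5]];
        let stream = stream_paginated(pages, |pages, token| async move {
            let idx: usize = token.as_deref().map_or(0, |t| t.parse().unwrap());
            let page = pages[idx].clone();
            let next = (idx + 1 < pages.len()).then(|| (idx + 1).to_string());
            Ok((page, pages, next))
        });
        let pages: Vec<Vec<u32>> = stream.try_collect().await?;
        Ok(pages.into_iter().flatten().collect())
    }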
diff --git a/object_store/src/client/token.rs b/object_store/src/client/token.rs
index a56a29462..2ff28616e 100644
--- a/object_store/src/client/token.rs
+++ b/object_store/src/client/token.rs
@@ -30,11 +30,19 @@ pub struct TemporaryToken<T> {
/// Provides [`TokenCache::get_or_insert_with`] which can be used to cache a
/// [`TemporaryToken`] based on its expiry
-#[derive(Debug, Default)]
+#[derive(Debug)]
pub struct TokenCache<T> {
cache: Mutex<Option<TemporaryToken<T>>>,
}
+impl<T> Default for TokenCache<T> {
+ fn default() -> Self {
+ Self {
+ cache: Default::default(),
+ }
+ }
+}
+
impl<T: Clone + Send> TokenCache<T> {
pub async fn get_or_insert_with<F, Fut, E>(&self, f: F) -> Result<T, E>
where
diff --git a/object_store/src/gcp.rs b/object_store/src/gcp.rs
index 0dc5a956a..c9bb63359 100644
--- a/object_store/src/gcp.rs
+++ b/object_store/src/gcp.rs
@@ -38,7 +38,6 @@ use std::sync::Arc;
use async_trait::async_trait;
use bytes::{Buf, Bytes};
use chrono::{DateTime, Utc};
-use futures::future::BoxFuture;
use futures::{stream::BoxStream, StreamExt, TryStreamExt};
use percent_encoding::{percent_encode, NON_ALPHANUMERIC};
use reqwest::header::RANGE;
@@ -46,6 +45,7 @@ use reqwest::{header, Client, Method, Response, StatusCode};
use snafu::{ResultExt, Snafu};
use tokio::io::AsyncWrite;
+use crate::client::pagination::stream_paginated;
use crate::client::retry::RetryExt;
use crate::{
client::{oauth::OAuthProvider, token::TokenCache},
@@ -476,44 +476,16 @@ impl GoogleCloudStorageClient {
&self,
prefix: Option<&Path>,
delimiter: bool,
- ) -> Result<BoxStream<'_, Result<ListResponse>>> {
+ ) -> BoxStream<'_, Result<ListResponse>> {
let prefix = format_prefix(prefix);
-
- enum ListState {
- Start,
- HasMore(String),
- Done,
- }
-
- Ok(futures::stream::unfold(ListState::Start, move |state| {
- let prefix = prefix.clone();
-
- async move {
- let page_token = match &state {
- ListState::Start => None,
- ListState::HasMore(page_token) => Some(page_token.as_str()),
- ListState::Done => {
- return None;
- }
- };
-
- let resp = match self
- .list_request(prefix.as_deref(), delimiter, page_token)
- .await
- {
- Ok(resp) => resp,
- Err(e) => return Some((Err(e), state)),
- };
-
- let next_state = match &resp.next_page_token {
- Some(token) => ListState::HasMore(token.clone()),
- None => ListState::Done,
- };
-
- Some((Ok(resp), next_state))
- }
+ stream_paginated(prefix, move |prefix, token| async move {
+ let mut r = self
+ .list_request(prefix.as_deref(), delimiter, token.as_deref())
+ .await?;
+ let next_token = r.next_page_token.take();
+ Ok((r, prefix, next_token))
})
- .boxed())
+ .boxed()
}
}
@@ -544,116 +516,105 @@ struct GCSMultipartUpload {
multipart_id: MultipartId,
}
+#[async_trait]
impl CloudMultiPartUploadImpl for GCSMultipartUpload {
/// Upload an object part <https://cloud.google.com/storage/docs/xml-api/put-object-multipart>
- fn put_multipart_part(
+ async fn put_multipart_part(
&self,
buf: Vec<u8>,
part_idx: usize,
- ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>> {
+ ) -> Result<UploadPart, io::Error> {
let upload_id = self.multipart_id.clone();
let url = format!(
"{}/{}/{}",
self.client.base_url, self.client.bucket_name_encoded, self.encoded_path
);
- let client = Arc::clone(&self.client);
-
- Box::pin(async move {
- let token = client
- .get_token()
- .await
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
-
- let response = client
- .client
- .request(Method::PUT, &url)
- .bearer_auth(token)
- .query(&[
- ("partNumber", format!("{}", part_idx + 1)),
- ("uploadId", upload_id),
- ])
- .header(header::CONTENT_TYPE, "application/octet-stream")
- .header(header::CONTENT_LENGTH, format!("{}", buf.len()))
- .body(buf)
- .send_retry(&client.retry_config)
- .await
- .map_err(reqwest_error_as_io)?
- .error_for_status()
- .map_err(reqwest_error_as_io)?;
-
- let content_id = response
- .headers()
- .get("ETag")
- .ok_or_else(|| {
- io::Error::new(
- io::ErrorKind::InvalidData,
- "response headers missing ETag",
- )
- })?
- .to_str()
- .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
- .to_string();
- Ok((part_idx, UploadPart { content_id }))
- })
+ let token = self
+ .client
+ .get_token()
+ .await
+ .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
+
+ let response = self
+ .client
+ .client
+ .request(Method::PUT, &url)
+ .bearer_auth(token)
+ .query(&[
+ ("partNumber", format!("{}", part_idx + 1)),
+ ("uploadId", upload_id),
+ ])
+ .header(header::CONTENT_TYPE, "application/octet-stream")
+ .header(header::CONTENT_LENGTH, format!("{}", buf.len()))
+ .body(buf)
+ .send_retry(&self.client.retry_config)
+ .await
+ .map_err(reqwest_error_as_io)?
+ .error_for_status()
+ .map_err(reqwest_error_as_io)?;
+
+ let content_id = response
+ .headers()
+ .get("ETag")
+ .ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::InvalidData,
+ "response headers missing ETag",
+ )
+ })?
+ .to_str()
+ .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?
+ .to_string();
+
+ Ok(UploadPart { content_id })
}
/// Complete a multipart upload <https://cloud.google.com/storage/docs/xml-api/post-object-complete>
- fn complete(
- &self,
- completed_parts: Vec<Option<UploadPart>>,
- ) -> BoxFuture<'static, Result<(), io::Error>> {
- let client = Arc::clone(&self.client);
+ async fn complete(&self, completed_parts: Vec<UploadPart>) -> Result<(), io::Error> {
let upload_id = self.multipart_id.clone();
let url = format!(
"{}/{}/{}",
self.client.base_url, self.client.bucket_name_encoded, self.encoded_path
);
- Box::pin(async move {
- let parts: Vec<MultipartPart> = completed_parts
- .into_iter()
- .enumerate()
- .map(|(part_number, maybe_part)| match maybe_part {
- Some(part) => Ok(MultipartPart {
- e_tag: part.content_id,
- part_number: part_number + 1,
- }),
- None => Err(io::Error::new(
- io::ErrorKind::Other,
- format!("Missing information for upload part {:?}", part_number),
- )),
- })
- .collect::<Result<Vec<MultipartPart>, io::Error>>()?;
-
- let token = client
- .get_token()
- .await
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
-
- let upload_info = CompleteMultipartUpload { parts };
-
- let data = quick_xml::se::to_string(&upload_info)
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?
- // We cannot disable the escaping that transforms "\"" to "&quot;" :(
- // https://github.com/tafia/quick-xml/issues/362
- // https://github.com/tafia/quick-xml/issues/350
- .replace("&quot;", "\"");
-
- client
- .client
- .request(Method::POST, &url)
- .bearer_auth(token)
- .query(&[("uploadId", upload_id)])
- .body(data)
- .send_retry(&client.retry_config)
- .await
- .map_err(reqwest_error_as_io)?
- .error_for_status()
- .map_err(reqwest_error_as_io)?;
-
- Ok(())
- })
+ let parts = completed_parts
+ .into_iter()
+ .enumerate()
+ .map(|(part_number, part)| MultipartPart {
+ e_tag: part.content_id,
+ part_number: part_number + 1,
+ })
+ .collect();
+
+ let token = self
+ .client
+ .get_token()
+ .await
+ .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
+
+ let upload_info = CompleteMultipartUpload { parts };
+
+ let data = quick_xml::se::to_string(&upload_info)
+ .map_err(|err| io::Error::new(io::ErrorKind::Other, err))?
+ // We cannot disable the escaping that transforms "\"" to "&quot;" :(
+ // https://github.com/tafia/quick-xml/issues/362
+ // https://github.com/tafia/quick-xml/issues/350
+ .replace("&quot;", "\"");
+
+ self.client
+ .client
+ .request(Method::POST, &url)
+ .bearer_auth(token)
+ .query(&[("uploadId", upload_id)])
+ .body(data)
+ .send_retry(&self.client.retry_config)
+ .await
+ .map_err(reqwest_error_as_io)?
+ .error_for_status()
+ .map_err(reqwest_error_as_io)?;
+
+ Ok(())
}
}
@@ -734,7 +695,7 @@ impl ObjectStore for GoogleCloudStorage {
) -> Result<BoxStream<'_, Result<ObjectMeta>>> {
let stream = self
.client
- .list_paginated(prefix, false)?
+ .list_paginated(prefix, false)
.map_ok(|r| {
futures::stream::iter(
r.items.into_iter().map(|x| convert_object_meta(&x)),
@@ -747,7 +708,7 @@ impl ObjectStore for GoogleCloudStorage {
}
async fn list_with_delimiter(&self, prefix: Option<&Path>) -> Result<ListResult> {
- let mut stream = self.client.list_paginated(prefix, true)?;
+ let mut stream = self.client.list_paginated(prefix, true);
let mut common_prefixes = BTreeSet::new();
let mut objects = Vec::new();
diff --git a/object_store/src/lib.rs b/object_store/src/lib.rs
index f7adedb26..374f5592e 100644
--- a/object_store/src/lib.rs
+++ b/object_store/src/lib.rs
@@ -165,10 +165,10 @@ pub mod memory;
pub mod path;
pub mod throttle;
-#[cfg(feature = "gcp")]
+#[cfg(any(feature = "gcp", feature = "aws"))]
mod client;
-#[cfg(feature = "gcp")]
+#[cfg(any(feature = "gcp", feature = "aws"))]
pub use client::{backoff::BackoffConfig, retry::RetryConfig};
#[cfg(any(feature = "azure", feature = "aws", feature = "gcp"))]
@@ -471,6 +471,16 @@ pub enum Error {
OAuth { source: client::oauth::Error },
}
+impl From<Error> for std::io::Error {
+ fn from(e: Error) -> Self {
+ let kind = match &e {
+ Error::NotFound { .. } => std::io::ErrorKind::NotFound,
+ _ => std::io::ErrorKind::Other,
+ };
+ Self::new(kind, e)
+ }
+}
+
#[cfg(test)]
mod test_util {
use super::*;
diff --git a/object_store/src/multipart.rs b/object_store/src/multipart.rs
index c16022d37..1985d8694 100644
--- a/object_store/src/multipart.rs
+++ b/object_store/src/multipart.rs
@@ -15,7 +15,8 @@
// specific language governing permissions and limitations
// under the License.
-use futures::{future::BoxFuture, stream::FuturesUnordered, Future, StreamExt};
+use async_trait::async_trait;
+use futures::{stream::FuturesUnordered, Future, StreamExt};
use std::{io, pin::Pin, sync::Arc, task::Poll};
use tokio::io::AsyncWrite;
@@ -26,23 +27,19 @@ type BoxedTryFuture<T> = Pin<Box<dyn Future<Output = Result<T, io::Error>> + Sen
/// A trait that can be implemented by cloud-based object stores
/// and used in combination with [`CloudMultiPartUpload`] to provide
/// multipart upload support
-///
-/// Note: this does not use AsyncTrait as the lifetimes are difficult to manage
-pub(crate) trait CloudMultiPartUploadImpl {
+#[async_trait]
+pub(crate) trait CloudMultiPartUploadImpl: 'static {
/// Upload a single part
- fn put_multipart_part(
+ async fn put_multipart_part(
&self,
buf: Vec<u8>,
part_idx: usize,
- ) -> BoxFuture<'static, Result<(usize, UploadPart), io::Error>>;
+ ) -> Result<UploadPart, io::Error>;
/// Complete the upload with the provided parts
///
/// `completed_parts` is in order of part number
- fn complete(
- &self,
- completed_parts: Vec<Option<UploadPart>>,
- ) -> BoxFuture<'static, Result<(), io::Error>>;
+ async fn complete(&self, completed_parts: Vec<UploadPart>) -> Result<(), io::Error>;
}
#[derive(Debug, Clone)]
@@ -128,10 +125,12 @@ where
self.current_buffer.extend_from_slice(buf);
let out_buffer = std::mem::take(&mut self.current_buffer);
- let task = self
- .inner
- .put_multipart_part(out_buffer, self.current_part_idx);
- self.tasks.push(task);
+ let inner = Arc::clone(&self.inner);
+ let part_idx = self.current_part_idx;
+ self.tasks.push(Box::pin(async move {
+ let upload_part = inner.put_multipart_part(out_buffer, part_idx).await?;
+ Ok((part_idx, upload_part))
+ }));
self.current_part_idx += 1;
// We need to poll immediately after adding to setup waker
@@ -157,10 +156,12 @@ where
// If current_buffer is not empty, see if it can be submitted
if !self.current_buffer.is_empty() && self.tasks.len() < self.max_concurrency {
let out_buffer: Vec<u8> = std::mem::take(&mut self.current_buffer);
- let task = self
- .inner
- .put_multipart_part(out_buffer, self.current_part_idx);
- self.tasks.push(task);
+ let inner = Arc::clone(&self.inner);
+ let part_idx = self.current_part_idx;
+ self.tasks.push(Box::pin(async move {
+ let upload_part = inner.put_multipart_part(out_buffer, part_idx).await?;
+ Ok((part_idx, upload_part))
+ }));
}
self.as_mut().poll_tasks(cx)?;
@@ -185,10 +186,26 @@ where
// If shutdown task is not set, set it
let parts = std::mem::take(&mut self.completed_parts);
+ let parts = parts
+ .into_iter()
+ .enumerate()
+ .map(|(idx, part)| {
+ part.ok_or_else(|| {
+ io::Error::new(
+ io::ErrorKind::Other,
+ format!("Missing information for upload part {}", idx),
+ )
+ })
+ })
+ .collect::<Result<_, _>>()?;
+
let inner = Arc::clone(&self.inner);
- let completion_task = self
- .completion_task
- .get_or_insert_with(|| inner.complete(parts));
+ let completion_task = self.completion_task.get_or_insert_with(|| {
+ Box::pin(async move {
+ inner.complete(parts).await?;
+ Ok(())
+ })
+ });
Pin::new(completion_task).poll(cx)
}