You are viewing a plain-text version of this content. The canonical link for it (a hyperlink labelled "here" in the original page) is not preserved in this version.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/05/17 11:13:10 UTC
[arrow-rs] branch master updated: Standardise credentials API (#4223) (#4163) (#4225)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 695356111 Standardise credentials API (#4223) (#4163) (#4225)
695356111 is described below
commit 69535611176f95f302c10e25a98f5b49af683d8b
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Wed May 17 12:13:04 2023 +0100
Standardise credentials API (#4223) (#4163) (#4225)
* Standardise credentials API (#4223) (#4163)
* Clippy
* Allow HTTP metadata endpoint
---
object_store/src/aws/client.rs | 6 +-
object_store/src/aws/credential.rs | 91 ++++++++---------
object_store/src/aws/mod.rs | 60 +++++------
object_store/src/aws/profile.rs | 71 ++++++-------
object_store/src/azure/client.rs | 52 ++--------
object_store/src/azure/credential.rs | 131 ++++++++++++------------
object_store/src/azure/mod.rs | 65 ++++++------
object_store/src/client/mod.rs | 89 ++++++++++++++++-
object_store/src/gcp/credential.rs | 187 ++++++++++++++++++-----------------
object_store/src/gcp/mod.rs | 121 ++++++++++-------------
10 files changed, 461 insertions(+), 412 deletions(-)
diff --git a/object_store/src/aws/client.rs b/object_store/src/aws/client.rs
index 1cdf785e5..8ce743b31 100644
--- a/object_store/src/aws/client.rs
+++ b/object_store/src/aws/client.rs
@@ -16,8 +16,8 @@
// under the License.
use crate::aws::checksum::Checksum;
-use crate::aws::credential::{AwsCredential, CredentialExt, CredentialProvider};
-use crate::aws::{STORE, STRICT_PATH_ENCODE_SET};
+use crate::aws::credential::{AwsCredential, CredentialExt};
+use crate::aws::{AwsCredentialProvider, STORE, STRICT_PATH_ENCODE_SET};
use crate::client::list::ListResponse;
use crate::client::pagination::stream_paginated;
use crate::client::retry::RetryExt;
@@ -135,7 +135,7 @@ pub struct S3Config {
pub endpoint: String,
pub bucket: String,
pub bucket_endpoint: String,
- pub credentials: Box<dyn CredentialProvider>,
+ pub credentials: AwsCredentialProvider,
pub retry_config: RetryConfig,
pub client_options: ClientOptions,
pub sign_payload: bool,
diff --git a/object_store/src/aws/credential.rs b/object_store/src/aws/credential.rs
index 9e047941a..47d681c63 100644
--- a/object_store/src/aws/credential.rs
+++ b/object_store/src/aws/credential.rs
@@ -18,12 +18,12 @@
use crate::aws::{STORE, STRICT_ENCODE_SET};
use crate::client::retry::RetryExt;
use crate::client::token::{TemporaryToken, TokenCache};
+use crate::client::TokenProvider;
use crate::util::hmac_sha256;
use crate::{Result, RetryConfig};
+use async_trait::async_trait;
use bytes::Buf;
use chrono::{DateTime, Utc};
-use futures::future::BoxFuture;
-use futures::TryFutureExt;
use percent_encoding::utf8_percent_encode;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{Client, Method, Request, RequestBuilder, StatusCode};
@@ -41,10 +41,14 @@ static EMPTY_SHA256_HASH: &str =
"e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
static UNSIGNED_PAYLOAD_LITERAL: &str = "UNSIGNED-PAYLOAD";
-#[derive(Debug)]
+/// A set of AWS security credentials
+#[derive(Debug, Eq, PartialEq)]
pub struct AwsCredential {
+ /// AWS_ACCESS_KEY_ID
pub key_id: String,
+ /// AWS_SECRET_ACCESS_KEY
pub secret_key: String,
+ /// AWS_SESSION_TOKEN
pub token: Option<String>,
}
@@ -291,49 +295,31 @@ fn canonicalize_headers(header_map: &HeaderMap) -> (String, String) {
(signed_headers, canonical_headers)
}
-/// Provides credentials for use when signing requests
-pub trait CredentialProvider: std::fmt::Debug + Send + Sync {
- fn get_credential(&self) -> BoxFuture<'_, Result<Arc<AwsCredential>>>;
-}
-
-/// A static set of credentials
-#[derive(Debug)]
-pub struct StaticCredentialProvider {
- pub credential: Arc<AwsCredential>,
-}
-
-impl CredentialProvider for StaticCredentialProvider {
- fn get_credential(&self) -> BoxFuture<'_, Result<Arc<AwsCredential>>> {
- Box::pin(futures::future::ready(Ok(Arc::clone(&self.credential))))
- }
-}
-
/// Credentials sourced from the instance metadata service
///
/// <https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html>
#[derive(Debug)]
pub struct InstanceCredentialProvider {
pub cache: TokenCache<Arc<AwsCredential>>,
- pub client: Client,
- pub retry_config: RetryConfig,
pub imdsv1_fallback: bool,
pub metadata_endpoint: String,
}
-impl CredentialProvider for InstanceCredentialProvider {
- fn get_credential(&self) -> BoxFuture<'_, Result<Arc<AwsCredential>>> {
- Box::pin(self.cache.get_or_insert_with(|| {
- instance_creds(
- &self.client,
- &self.retry_config,
- &self.metadata_endpoint,
- self.imdsv1_fallback,
- )
+#[async_trait]
+impl TokenProvider for InstanceCredentialProvider {
+ type Credential = AwsCredential;
+
+ async fn fetch_token(
+ &self,
+ client: &Client,
+ retry: &RetryConfig,
+ ) -> Result<TemporaryToken<Arc<AwsCredential>>> {
+ instance_creds(client, retry, &self.metadata_endpoint, self.imdsv1_fallback)
+ .await
.map_err(|source| crate::Error::Generic {
store: STORE,
source,
})
- }))
}
}
@@ -342,31 +328,34 @@ impl CredentialProvider for InstanceCredentialProvider {
/// <https://docs.aws.amazon.com/eks/latest/userguide/iam-roles-for-service-accounts-technical-overview.html>
#[derive(Debug)]
pub struct WebIdentityProvider {
- pub cache: TokenCache<Arc<AwsCredential>>,
pub token_path: String,
pub role_arn: String,
pub session_name: String,
pub endpoint: String,
- pub client: Client,
- pub retry_config: RetryConfig,
}
-impl CredentialProvider for WebIdentityProvider {
- fn get_credential(&self) -> BoxFuture<'_, Result<Arc<AwsCredential>>> {
- Box::pin(self.cache.get_or_insert_with(|| {
- web_identity(
- &self.client,
- &self.retry_config,
- &self.token_path,
- &self.role_arn,
- &self.session_name,
- &self.endpoint,
- )
- .map_err(|source| crate::Error::Generic {
- store: STORE,
- source,
- })
- }))
+#[async_trait]
+impl TokenProvider for WebIdentityProvider {
+ type Credential = AwsCredential;
+
+ async fn fetch_token(
+ &self,
+ client: &Client,
+ retry: &RetryConfig,
+ ) -> Result<TemporaryToken<Arc<AwsCredential>>> {
+ web_identity(
+ client,
+ retry,
+ &self.token_path,
+ &self.role_arn,
+ &self.session_name,
+ &self.endpoint,
+ )
+ .await
+ .map_err(|source| crate::Error::Generic {
+ store: STORE,
+ source,
+ })
}
}
diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs
index 428e013f4..ddb9dc799 100644
--- a/object_store/src/aws/mod.rs
+++ b/object_store/src/aws/mod.rs
@@ -48,11 +48,13 @@ use url::Url;
pub use crate::aws::checksum::Checksum;
use crate::aws::client::{S3Client, S3Config};
use crate::aws::credential::{
- AwsCredential, CredentialProvider, InstanceCredentialProvider,
- StaticCredentialProvider, WebIdentityProvider,
+ AwsCredential, InstanceCredentialProvider, WebIdentityProvider,
};
use crate::client::header::header_meta;
-use crate::client::ClientConfigKey;
+use crate::client::{
+ ClientConfigKey, CredentialProvider, StaticCredentialProvider,
+ TokenCredentialProvider,
+};
use crate::config::ConfigValue;
use crate::multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart};
use crate::{
@@ -83,6 +85,8 @@ const STRICT_PATH_ENCODE_SET: percent_encoding::AsciiSet = STRICT_ENCODE_SET.rem
const STORE: &str = "S3";
+type AwsCredentialProvider = Arc<dyn CredentialProvider<Credential = AwsCredential>>;
+
/// Default metadata endpoint
static METADATA_ENDPOINT: &str = "http://169.254.169.254";
@@ -1001,13 +1005,12 @@ impl AmazonS3Builder {
let credentials = match (self.access_key_id, self.secret_access_key, self.token) {
(Some(key_id), Some(secret_key), token) => {
info!("Using Static credential provider");
- Box::new(StaticCredentialProvider {
- credential: Arc::new(AwsCredential {
- key_id,
- secret_key,
- token,
- }),
- }) as _
+ let credential = AwsCredential {
+ key_id,
+ secret_key,
+ token,
+ };
+ Arc::new(StaticCredentialProvider::new(credential)) as _
}
(None, Some(_), _) => return Err(Error::MissingAccessKeyId.into()),
(Some(_), None, _) => return Err(Error::MissingSecretAccessKey.into()),
@@ -1031,15 +1034,18 @@ impl AmazonS3Builder {
.with_allow_http(false)
.client()?;
- Box::new(WebIdentityProvider {
- cache: Default::default(),
+ let token = WebIdentityProvider {
token_path,
session_name,
role_arn,
endpoint,
+ };
+
+ Arc::new(TokenCredentialProvider::new(
+ token,
client,
- retry_config: self.retry_config.clone(),
- }) as _
+ self.retry_config.clone(),
+ )) as _
}
_ => match self.profile {
Some(profile) => {
@@ -1049,19 +1055,20 @@ impl AmazonS3Builder {
None => {
info!("Using Instance credential provider");
- // The instance metadata endpoint is access over HTTP
- let client_options =
- self.client_options.clone().with_allow_http(true);
-
- Box::new(InstanceCredentialProvider {
+ let token = InstanceCredentialProvider {
cache: Default::default(),
- client: client_options.client()?,
- retry_config: self.retry_config.clone(),
imdsv1_fallback: self.imdsv1_fallback.get()?,
metadata_endpoint: self
.metadata_endpoint
.unwrap_or_else(|| METADATA_ENDPOINT.into()),
- }) as _
+ };
+
+ Arc::new(TokenCredentialProvider::new(
+ token,
+ // The instance metadata endpoint is access over HTTP
+ self.client_options.clone().with_allow_http(true).client()?,
+ self.retry_config.clone(),
+ )) as _
}
},
},
@@ -1114,11 +1121,8 @@ fn profile_region(profile: String) -> Option<String> {
}
#[cfg(feature = "aws_profile")]
-fn profile_credentials(
- profile: String,
- region: String,
-) -> Result<Box<dyn CredentialProvider>> {
- Ok(Box::new(profile::ProfileProvider::new(
+fn profile_credentials(profile: String, region: String) -> Result<AwsCredentialProvider> {
+ Ok(Arc::new(profile::ProfileProvider::new(
profile,
Some(region),
)))
@@ -1133,7 +1137,7 @@ fn profile_region(_profile: String) -> Option<String> {
fn profile_credentials(
_profile: String,
_region: String,
-) -> Result<Box<dyn CredentialProvider>> {
+) -> Result<AwsCredentialProvider> {
Err(Error::MissingProfileFeature.into())
}
diff --git a/object_store/src/aws/profile.rs b/object_store/src/aws/profile.rs
index a88824c79..3fc080564 100644
--- a/object_store/src/aws/profile.rs
+++ b/object_store/src/aws/profile.rs
@@ -17,6 +17,7 @@
#![cfg(feature = "aws_profile")]
+use async_trait::async_trait;
use aws_config::meta::region::ProvideRegion;
use aws_config::profile::profile_file::ProfileFiles;
use aws_config::profile::ProfileFileCredentialsProvider;
@@ -24,14 +25,13 @@ use aws_config::profile::ProfileFileRegionProvider;
use aws_config::provider_config::ProviderConfig;
use aws_credential_types::provider::ProvideCredentials;
use aws_types::region::Region;
-use futures::future::BoxFuture;
use std::sync::Arc;
use std::time::Instant;
use std::time::SystemTime;
-use crate::aws::credential::CredentialProvider;
use crate::aws::AwsCredential;
use crate::client::token::{TemporaryToken, TokenCache};
+use crate::client::CredentialProvider;
use crate::Result;
#[cfg(test)]
@@ -91,38 +91,43 @@ impl ProfileProvider {
}
}
+#[async_trait]
impl CredentialProvider for ProfileProvider {
- fn get_credential(&self) -> BoxFuture<'_, Result<Arc<AwsCredential>>> {
- Box::pin(self.cache.get_or_insert_with(move || async move {
- let region = self.region.clone().map(Region::new);
-
- let config = ProviderConfig::default().with_region(region);
-
- let credentials = ProfileFileCredentialsProvider::builder()
- .configure(&config)
- .profile_name(&self.name)
- .build();
-
- let c = credentials.provide_credentials().await.map_err(|source| {
- crate::Error::Generic {
- store: "S3",
- source: Box::new(source),
- }
- })?;
- let t_now = SystemTime::now();
- let expiry = c
- .expiry()
- .and_then(|e| e.duration_since(t_now).ok())
- .map(|ttl| Instant::now() + ttl);
-
- Ok(TemporaryToken {
- token: Arc::new(AwsCredential {
- key_id: c.access_key_id().to_string(),
- secret_key: c.secret_access_key().to_string(),
- token: c.session_token().map(ToString::to_string),
- }),
- expiry,
+ type Credential = AwsCredential;
+
+ async fn get_credential(&self) -> Result<Arc<AwsCredential>> {
+ self.cache
+ .get_or_insert_with(move || async move {
+ let region = self.region.clone().map(Region::new);
+
+ let config = ProviderConfig::default().with_region(region);
+
+ let credentials = ProfileFileCredentialsProvider::builder()
+ .configure(&config)
+ .profile_name(&self.name)
+ .build();
+
+ let c = credentials.provide_credentials().await.map_err(|source| {
+ crate::Error::Generic {
+ store: "S3",
+ source: Box::new(source),
+ }
+ })?;
+ let t_now = SystemTime::now();
+ let expiry = c
+ .expiry()
+ .and_then(|e| e.duration_since(t_now).ok())
+ .map(|ttl| Instant::now() + ttl);
+
+ Ok(TemporaryToken {
+ token: Arc::new(AwsCredential {
+ key_id: c.access_key_id().to_string(),
+ secret_key: c.secret_access_key().to_string(),
+ token: c.session_token().map(ToString::to_string),
+ }),
+ expiry,
+ })
})
- }))
+ .await
}
}
diff --git a/object_store/src/azure/client.rs b/object_store/src/azure/client.rs
index 893e261fe..5f165c007 100644
--- a/object_store/src/azure/client.rs
+++ b/object_store/src/azure/client.rs
@@ -15,9 +15,9 @@
// specific language governing permissions and limitations
// under the License.
-use super::credential::{AzureCredential, CredentialProvider};
+use super::credential::AzureCredential;
use crate::azure::credential::*;
-use crate::azure::STORE;
+use crate::azure::{AzureCredentialProvider, STORE};
use crate::client::pagination::stream_paginated;
use crate::client::retry::RetryExt;
use crate::client::GetOptionsExt;
@@ -40,6 +40,7 @@ use reqwest::{
use serde::{Deserialize, Serialize};
use snafu::{ResultExt, Snafu};
use std::collections::HashMap;
+use std::sync::Arc;
use url::Url;
/// A specialized `Error` for object store-related errors
@@ -101,10 +102,10 @@ impl From<Error> for crate::Error {
/// Configuration for [AzureClient]
#[derive(Debug)]
-pub struct AzureConfig {
+pub(crate) struct AzureConfig {
pub account: String,
pub container: String,
- pub credentials: CredentialProvider,
+ pub credentials: AzureCredentialProvider,
pub retry_config: RetryConfig,
pub service: Url,
pub is_emulator: bool,
@@ -143,45 +144,8 @@ impl AzureClient {
&self.config
}
- async fn get_credential(&self) -> Result<AzureCredential> {
- match &self.config.credentials {
- CredentialProvider::AccessKey(key) => {
- Ok(AzureCredential::AccessKey(key.to_owned()))
- }
- CredentialProvider::BearerToken(token) => {
- Ok(AzureCredential::AuthorizationToken(
- // we do the conversion to a HeaderValue here, since it is fallible
- // and we want to use it in an infallible function
- HeaderValue::from_str(&format!("Bearer {token}")).map_err(|err| {
- crate::Error::Generic {
- store: STORE,
- source: Box::new(err),
- }
- })?,
- ))
- }
- CredentialProvider::TokenCredential(cache, cred) => {
- let token = cache
- .get_or_insert_with(|| {
- cred.fetch_token(&self.client, &self.config.retry_config)
- })
- .await
- .context(AuthorizationSnafu)?;
- Ok(AzureCredential::AuthorizationToken(
- // we do the conversion to a HeaderValue here, since it is fallible
- // and we want to use it in an infallible function
- HeaderValue::from_str(&format!("Bearer {token}")).map_err(|err| {
- crate::Error::Generic {
- store: STORE,
- source: Box::new(err),
- }
- })?,
- ))
- }
- CredentialProvider::SASToken(sas) => {
- Ok(AzureCredential::SASToken(sas.clone()))
- }
- }
+ async fn get_credential(&self) -> Result<Arc<AzureCredential>> {
+ self.config.credentials.get_credential().await
}
/// Make an Azure PUT request <https://docs.microsoft.com/en-us/rest/api/storageservices/put-blob>
@@ -308,7 +272,7 @@ impl AzureClient {
// If using SAS authorization must include the headers in the URL
// <https://docs.microsoft.com/en-us/rest/api/storageservices/copy-blob#request-headers>
- if let AzureCredential::SASToken(pairs) = &credential {
+ if let AzureCredential::SASToken(pairs) = credential.as_ref() {
source.query_pairs_mut().extend_pairs(pairs);
}
diff --git a/object_store/src/azure/credential.rs b/object_store/src/azure/credential.rs
index 8130df636..fd7538924 100644
--- a/object_store/src/azure/credential.rs
+++ b/object_store/src/azure/credential.rs
@@ -15,10 +15,13 @@
// specific language governing permissions and limitations
// under the License.
+use crate::azure::STORE;
use crate::client::retry::RetryExt;
use crate::client::token::{TemporaryToken, TokenCache};
+use crate::client::{CredentialProvider, TokenProvider};
use crate::util::hmac_sha256;
use crate::RetryConfig;
+use async_trait::async_trait;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use chrono::{DateTime, Utc};
@@ -36,6 +39,7 @@ use snafu::{ResultExt, Snafu};
use std::borrow::Cow;
use std::process::Command;
use std::str;
+use std::sync::Arc;
use std::time::{Duration, Instant};
use url::Url;
@@ -81,19 +85,30 @@ pub enum Error {
pub type Result<T, E = Error> = std::result::Result<T, E>;
-/// Provides credentials for use when signing requests
-#[derive(Debug)]
-pub enum CredentialProvider {
- AccessKey(String),
- BearerToken(String),
- SASToken(Vec<(String, String)>),
- TokenCredential(TokenCache<String>, Box<dyn TokenCredential>),
+impl From<Error> for crate::Error {
+ fn from(value: Error) -> Self {
+ Self::Generic {
+ store: STORE,
+ source: Box::new(value),
+ }
+ }
}
-pub(crate) enum AzureCredential {
+/// An Azure storage credential
+#[derive(Debug, Eq, PartialEq)]
+pub enum AzureCredential {
+ /// A shared access key
+ ///
+ /// <https://learn.microsoft.com/en-us/rest/api/storageservices/authorize-with-shared-key>
AccessKey(String),
+ /// A shared access signature
+ ///
+ /// <https://learn.microsoft.com/en-us/rest/api/storageservices/delegate-access-with-shared-access-signature>
SASToken(Vec<(String, String)>),
- AuthorizationToken(HeaderValue),
+ /// An authorization token
+ ///
+ /// <https://learn.microsoft.com/en-us/rest/api/storageservices/authorize-with-azure-active-directory>
+ BearerToken(String),
}
/// A list of known Azure authority hosts
@@ -155,9 +170,7 @@ impl CredentialExt for RequestBuilder {
Self::from_parts(client, request)
}
- AzureCredential::AuthorizationToken(token) => {
- self.header(AUTHORIZATION, token)
- }
+ AzureCredential::BearerToken(token) => self.bearer_auth(token),
AzureCredential::SASToken(query_pairs) => self.query(&query_pairs),
}
}
@@ -291,15 +304,6 @@ fn lexy_sort<'a>(
values
}
-#[async_trait::async_trait]
-pub trait TokenCredential: std::fmt::Debug + Send + Sync + 'static {
- async fn fetch_token(
- &self,
- client: &Client,
- retry: &RetryConfig,
- ) -> Result<TemporaryToken<String>>;
-}
-
#[derive(Deserialize, Debug)]
struct TokenResponse {
access_token: String,
@@ -338,13 +342,15 @@ impl ClientSecretOAuthProvider {
}
#[async_trait::async_trait]
-impl TokenCredential for ClientSecretOAuthProvider {
+impl TokenProvider for ClientSecretOAuthProvider {
+ type Credential = AzureCredential;
+
/// Fetch a token
async fn fetch_token(
&self,
client: &Client,
retry: &RetryConfig,
- ) -> Result<TemporaryToken<String>> {
+ ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
let response: TokenResponse = client
.request(Method::POST, &self.token_url)
.header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON))
@@ -361,12 +367,10 @@ impl TokenCredential for ClientSecretOAuthProvider {
.await
.context(TokenResponseBodySnafu)?;
- let token = TemporaryToken {
- token: response.access_token,
+ Ok(TemporaryToken {
+ token: Arc::new(AzureCredential::BearerToken(response.access_token)),
expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
- };
-
- Ok(token)
+ })
}
}
@@ -397,7 +401,6 @@ pub struct ImdsManagedIdentityProvider {
client_id: Option<String>,
object_id: Option<String>,
msi_res_id: Option<String>,
- client: Client,
}
impl ImdsManagedIdentityProvider {
@@ -407,7 +410,6 @@ impl ImdsManagedIdentityProvider {
object_id: Option<String>,
msi_res_id: Option<String>,
msi_endpoint: Option<String>,
- client: Client,
) -> Self {
let msi_endpoint = msi_endpoint.unwrap_or_else(|| {
"http://169.254.169.254/metadata/identity/oauth2/token".to_owned()
@@ -418,19 +420,20 @@ impl ImdsManagedIdentityProvider {
client_id,
object_id,
msi_res_id,
- client,
}
}
}
#[async_trait::async_trait]
-impl TokenCredential for ImdsManagedIdentityProvider {
+impl TokenProvider for ImdsManagedIdentityProvider {
+ type Credential = AzureCredential;
+
/// Fetch a token
async fn fetch_token(
&self,
- _client: &Client,
+ client: &Client,
retry: &RetryConfig,
- ) -> Result<TemporaryToken<String>> {
+ ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
let mut query_items = vec![
("api-version", MSI_API_VERSION),
("resource", AZURE_STORAGE_RESOURCE),
@@ -450,8 +453,7 @@ impl TokenCredential for ImdsManagedIdentityProvider {
query_items.push((key, value));
}
- let mut builder = self
- .client
+ let mut builder = client
.request(Method::GET, &self.msi_endpoint)
.header("metadata", "true")
.query(&query_items);
@@ -468,12 +470,10 @@ impl TokenCredential for ImdsManagedIdentityProvider {
.await
.context(TokenResponseBodySnafu)?;
- let token = TemporaryToken {
- token: response.access_token,
+ Ok(TemporaryToken {
+ token: Arc::new(AzureCredential::BearerToken(response.access_token)),
expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
- };
-
- Ok(token)
+ })
}
}
@@ -511,13 +511,15 @@ impl WorkloadIdentityOAuthProvider {
}
#[async_trait::async_trait]
-impl TokenCredential for WorkloadIdentityOAuthProvider {
+impl TokenProvider for WorkloadIdentityOAuthProvider {
+ type Credential = AzureCredential;
+
/// Fetch a token
async fn fetch_token(
&self,
client: &Client,
retry: &RetryConfig,
- ) -> Result<TemporaryToken<String>> {
+ ) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
let token_str = std::fs::read_to_string(&self.federated_token_file)
.map_err(|_| Error::FederatedTokenFile)?;
@@ -542,12 +544,10 @@ impl TokenCredential for WorkloadIdentityOAuthProvider {
.await
.context(TokenResponseBodySnafu)?;
- let token = TemporaryToken {
- token: response.access_token,
+ Ok(TemporaryToken {
+ token: Arc::new(AzureCredential::BearerToken(response.access_token)),
expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
- };
-
- Ok(token)
+ })
}
}
@@ -585,23 +585,16 @@ struct AzureCliTokenResponse {
#[derive(Default, Debug)]
pub struct AzureCliCredential {
- _private: (),
+ cache: TokenCache<Arc<AzureCredential>>,
}
impl AzureCliCredential {
pub fn new() -> Self {
Self::default()
}
-}
-#[async_trait::async_trait]
-impl TokenCredential for AzureCliCredential {
/// Fetch a token
- async fn fetch_token(
- &self,
- _client: &Client,
- _retry: &RetryConfig,
- ) -> Result<TemporaryToken<String>> {
+ async fn fetch_token(&self) -> Result<TemporaryToken<Arc<AzureCredential>>> {
// on window az is a cmd and it should be called like this
// see https://doc.rust-lang.org/nightly/std/process/struct.Command.html
let program = if cfg!(target_os = "windows") {
@@ -642,7 +635,9 @@ impl TokenCredential for AzureCliCredential {
let duration = token_response.expires_on.naive_local()
- chrono::Local::now().naive_local();
Ok(TemporaryToken {
- token: token_response.access_token,
+ token: Arc::new(AzureCredential::BearerToken(
+ token_response.access_token,
+ )),
expiry: Some(
Instant::now()
+ duration.to_std().map_err(|_| Error::AzureCli {
@@ -669,6 +664,15 @@ impl TokenCredential for AzureCliCredential {
}
}
+#[async_trait]
+impl CredentialProvider for AzureCliCredential {
+ type Credential = AzureCredential;
+
+ async fn get_credential(&self) -> crate::Result<Arc<Self::Credential>> {
+ Ok(self.cache.get_or_insert_with(|| self.fetch_token()).await?)
+ }
+}
+
#[cfg(test)]
mod tests {
use super::*;
@@ -723,7 +727,6 @@ mod tests {
None,
None,
Some(format!("{endpoint}/metadata/identity/oauth2/token")),
- client.clone(),
);
let token = credential
@@ -731,7 +734,10 @@ mod tests {
.await
.unwrap();
- assert_eq!(&token.token, "TOKEN");
+ assert_eq!(
+ token.token.as_ref(),
+ &AzureCredential::BearerToken("TOKEN".into())
+ );
}
#[tokio::test]
@@ -779,6 +785,9 @@ mod tests {
.await
.unwrap();
- assert_eq!(&token.token, "TOKEN");
+ assert_eq!(
+ token.token.as_ref(),
+ &AzureCredential::BearerToken("TOKEN".into())
+ );
}
}
diff --git a/object_store/src/azure/mod.rs b/object_store/src/azure/mod.rs
index 0f8dae00c..6dc14cfb5 100644
--- a/object_store/src/azure/mod.rs
+++ b/object_store/src/azure/mod.rs
@@ -27,7 +27,6 @@
//! a way to drop old blocks. Instead unused blocks are automatically cleaned up
//! after 7 days.
use self::client::{BlockId, BlockList};
-use crate::client::token::TokenCache;
use crate::{
multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart},
path::Path,
@@ -49,14 +48,20 @@ use std::{collections::BTreeSet, str::FromStr};
use tokio::io::AsyncWrite;
use url::Url;
+use crate::azure::credential::AzureCredential;
use crate::client::header::header_meta;
-use crate::client::ClientConfigKey;
+use crate::client::{
+ ClientConfigKey, CredentialProvider, StaticCredentialProvider,
+ TokenCredentialProvider,
+};
use crate::config::ConfigValue;
pub use credential::authority_hosts;
mod client;
mod credential;
+type AzureCredentialProvider = Arc<dyn CredentialProvider<Credential = AzureCredential>>;
+
const STORE: &str = "MicrosoftAzure";
/// The well-known account used by Azurite and the legacy Azure Storage Emulator.
@@ -101,12 +106,6 @@ enum Error {
#[snafu(display("Container name must be specified"))]
MissingContainerName {},
- #[snafu(display("At least one authorization option must be specified"))]
- MissingCredentials {},
-
- #[snafu(display("Azure credential error: {}", source), context(false))]
- Credential { source: credential::Error },
-
#[snafu(display(
"Unknown url scheme cannot be parsed into storage location: {}",
scheme
@@ -913,6 +912,9 @@ impl MicrosoftAzureBuilder {
}
let container = self.container_name.ok_or(Error::MissingContainerName {})?;
+ let static_creds = |credential: AzureCredential| -> AzureCredentialProvider {
+ Arc::new(StaticCredentialProvider::new(credential))
+ };
let (is_emulator, storage_url, auth, account) = if self.use_emulator.get()? {
let account_name = self
@@ -924,7 +926,8 @@ impl MicrosoftAzureBuilder {
let account_key = self
.access_key
.unwrap_or_else(|| EMULATOR_ACCOUNT_KEY.to_string());
- let credential = credential::CredentialProvider::AccessKey(account_key);
+
+ let credential = static_creds(AzureCredential::AccessKey(account_key));
self.client_options = self.client_options.with_allow_http(true);
(true, url, credential, account_name)
@@ -933,10 +936,11 @@ impl MicrosoftAzureBuilder {
let account_url = format!("https://{}.blob.core.windows.net", &account_name);
let url = Url::parse(&account_url)
.context(UnableToParseUrlSnafu { url: account_url })?;
+
let credential = if let Some(bearer_token) = self.bearer_token {
- credential::CredentialProvider::BearerToken(bearer_token)
+ static_creds(AzureCredential::BearerToken(bearer_token))
} else if let Some(access_key) = self.access_key {
- credential::CredentialProvider::AccessKey(access_key)
+ static_creds(AzureCredential::AccessKey(access_key))
} else if let (Some(client_id), Some(tenant_id), Some(federated_token_file)) =
(&self.client_id, &self.tenant_id, self.federated_token_file)
{
@@ -946,10 +950,11 @@ impl MicrosoftAzureBuilder {
tenant_id,
self.authority_host,
);
- credential::CredentialProvider::TokenCredential(
- TokenCache::default(),
- Box::new(client_credential),
- )
+ Arc::new(TokenCredentialProvider::new(
+ client_credential,
+ self.client_options.client()?,
+ self.retry_config.clone(),
+ )) as _
} else if let (Some(client_id), Some(client_secret), Some(tenant_id)) =
(&self.client_id, self.client_secret, &self.tenant_id)
{
@@ -959,33 +964,29 @@ impl MicrosoftAzureBuilder {
tenant_id,
self.authority_host,
);
- credential::CredentialProvider::TokenCredential(
- TokenCache::default(),
- Box::new(client_credential),
- )
+ Arc::new(TokenCredentialProvider::new(
+ client_credential,
+ self.client_options.client()?,
+ self.retry_config.clone(),
+ )) as _
} else if let Some(query_pairs) = self.sas_query_pairs {
- credential::CredentialProvider::SASToken(query_pairs)
+ static_creds(AzureCredential::SASToken(query_pairs))
} else if let Some(sas) = self.sas_key {
- credential::CredentialProvider::SASToken(split_sas(&sas)?)
+ static_creds(AzureCredential::SASToken(split_sas(&sas)?))
} else if self.use_azure_cli.get()? {
- credential::CredentialProvider::TokenCredential(
- TokenCache::default(),
- Box::new(credential::AzureCliCredential::new()),
- )
+ Arc::new(credential::AzureCliCredential::new()) as _
} else {
- let client =
- self.client_options.clone().with_allow_http(true).client()?;
let msi_credential = credential::ImdsManagedIdentityProvider::new(
self.client_id,
self.object_id,
self.msi_resource_id,
self.msi_endpoint,
- client,
);
- credential::CredentialProvider::TokenCredential(
- TokenCache::default(),
- Box::new(msi_credential),
- )
+ Arc::new(TokenCredentialProvider::new(
+ msi_credential,
+ self.client_options.clone().with_allow_http(true).client()?,
+ self.retry_config.clone(),
+ )) as _
};
(false, url, credential, account_name)
};
diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs
index c6a73fe7a..292e4678f 100644
--- a/object_store/src/client/mod.rs
+++ b/object_store/src/client/mod.rs
@@ -32,17 +32,20 @@ pub mod header;
#[cfg(any(feature = "aws", feature = "gcp"))]
pub mod list;
+use async_trait::async_trait;
use std::collections::HashMap;
use std::str::FromStr;
+use std::sync::Arc;
use std::time::Duration;
use reqwest::header::{HeaderMap, HeaderValue};
use reqwest::{Client, ClientBuilder, Proxy, RequestBuilder};
use serde::{Deserialize, Serialize};
+use crate::client::token::{TemporaryToken, TokenCache};
use crate::config::{fmt_duration, ConfigValue};
use crate::path::Path;
-use crate::GetOptions;
+use crate::{GetOptions, Result, RetryConfig};
fn map_client_error(e: reqwest::Error) -> super::Error {
super::Error::Generic {
@@ -503,6 +506,90 @@ impl GetOptionsExt for RequestBuilder {
}
}
+/// Provides credentials for use when signing requests
+#[async_trait]
+pub trait CredentialProvider: std::fmt::Debug + Send + Sync {
+ type Credential;
+
+ async fn get_credential(&self) -> Result<Arc<Self::Credential>>;
+}
+
+/// A static set of credentials
+#[derive(Debug)]
+pub struct StaticCredentialProvider<T> {
+ credential: Arc<T>,
+}
+
+impl<T> StaticCredentialProvider<T> {
+ pub fn new(credential: T) -> Self {
+ Self {
+ credential: Arc::new(credential),
+ }
+ }
+}
+
+#[async_trait]
+impl<T> CredentialProvider for StaticCredentialProvider<T>
+where
+ T: std::fmt::Debug + Send + Sync,
+{
+ type Credential = T;
+
+ async fn get_credential(&self) -> Result<Arc<T>> {
+ Ok(Arc::clone(&self.credential))
+ }
+}
+
+#[cfg(any(feature = "aws", feature = "azure", feature = "gcp"))]
+mod cloud {
+ use super::*;
+
+ /// A [`CredentialProvider`] that uses [`Client`] to fetch temporary tokens
+ #[derive(Debug)]
+ pub struct TokenCredentialProvider<T: TokenProvider> {
+ inner: T,
+ client: Client,
+ retry: RetryConfig,
+ cache: TokenCache<Arc<T::Credential>>,
+ }
+
+ impl<T: TokenProvider> TokenCredentialProvider<T> {
+ pub fn new(inner: T, client: Client, retry: RetryConfig) -> Self {
+ Self {
+ inner,
+ client,
+ retry,
+ cache: Default::default(),
+ }
+ }
+ }
+
+ #[async_trait]
+ impl<T: TokenProvider> CredentialProvider for TokenCredentialProvider<T> {
+ type Credential = T::Credential;
+
+ async fn get_credential(&self) -> Result<Arc<Self::Credential>> {
+ self.cache
+ .get_or_insert_with(|| self.inner.fetch_token(&self.client, &self.retry))
+ .await
+ }
+ }
+
+ #[async_trait]
+ pub trait TokenProvider: std::fmt::Debug + Send + Sync {
+ type Credential: std::fmt::Debug + Send + Sync;
+
+ async fn fetch_token(
+ &self,
+ client: &Client,
+ retry: &RetryConfig,
+ ) -> Result<TemporaryToken<Arc<Self::Credential>>>;
+ }
+}
+
+#[cfg(any(feature = "aws", feature = "azure", feature = "gcp"))]
+pub use cloud::*;
+
#[cfg(test)]
mod tests {
use super::*;
diff --git a/object_store/src/gcp/credential.rs b/object_store/src/gcp/credential.rs
index 057e01333..ad12855e1 100644
--- a/object_store/src/gcp/credential.rs
+++ b/object_store/src/gcp/credential.rs
@@ -17,6 +17,9 @@
use crate::client::retry::RetryExt;
use crate::client::token::TemporaryToken;
+use crate::client::{TokenCredentialProvider, TokenProvider};
+use crate::gcp::credential::Error::UnsupportedCredentialsType;
+use crate::gcp::{GcpCredentialProvider, STORE};
use crate::ClientOptions;
use crate::RetryConfig;
use async_trait::async_trait;
@@ -30,6 +33,7 @@ use std::env;
use std::fs::File;
use std::io::BufReader;
use std::path::{Path, PathBuf};
+use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::info;
@@ -67,9 +71,21 @@ pub enum Error {
#[snafu(display("Unsupported ApplicationCredentials type: {}", type_))]
UnsupportedCredentialsType { type_: String },
+}
+
+/// Wraps credential errors as [`crate::Error::Generic`], tagged with the GCS store name
+impl From<Error> for crate::Error {
+ fn from(value: Error) -> Self {
+ Self::Generic {
+ store: STORE,
+ source: Box::new(value),
+ }
+ }
+}
- #[snafu(display("Error creating client: {}", source))]
- Client { source: crate::Error },
+// NOTE(review): the derived `Debug` prints `bearer`, so `{:?}` output of this type
+// exposes the token.
+/// A credential used to sign GCS requests
+#[derive(Debug, Eq, PartialEq)]
+pub struct GcpCredential {
+ /// An HTTP bearer token
+ pub bearer: String,
}
pub type Result<T, E = Error> = std::result::Result<T, E>;
@@ -127,15 +143,6 @@ struct TokenResponse {
expires_in: u64,
}
-#[async_trait]
-pub trait TokenProvider: std::fmt::Debug + Send + Sync {
- async fn fetch_token(
- &self,
- client: &Client,
- retry: &RetryConfig,
- ) -> Result<TemporaryToken<String>>;
-}
-
/// Encapsulates the logic to perform an OAuth token challenge
#[derive(Debug)]
pub struct OAuthProvider {
@@ -174,12 +181,14 @@ impl OAuthProvider {
#[async_trait]
impl TokenProvider for OAuthProvider {
+ // The OAuth access token is handed out as a bearer credential
+ type Credential = GcpCredential;
+
/// Fetch a fresh token
async fn fetch_token(
&self,
client: &Client,
retry: &RetryConfig,
- ) -> Result<TemporaryToken<String>> {
+ ) -> crate::Result<TemporaryToken<Arc<GcpCredential>>> {
let now = seconds_since_epoch();
let exp = now + 3600;
@@ -221,12 +230,12 @@ impl TokenProvider for OAuthProvider {
.await
.context(TokenResponseBodySnafu)?;
- let token = TemporaryToken {
- token: response.access_token,
+ Ok(TemporaryToken {
+ token: Arc::new(GcpCredential {
+ bearer: response.access_token,
+ }),
expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
- };
-
- Ok(token)
+ })
}
}
@@ -281,17 +290,17 @@ impl ServiceAccountCredentials {
}
/// Create an [`OAuthProvider`] from this credentials struct.
+ ///
+ /// `scope` and `audience` are forwarded to [`OAuthProvider::new`]; its errors are
+ /// converted into [`crate::Error`].
- pub fn token_provider(
+ pub fn oauth_provider(
self,
scope: &str,
audience: &str,
- ) -> Result<Box<dyn TokenProvider>> {
- Ok(Box::new(OAuthProvider::new(
+ ) -> crate::Result<OAuthProvider> {
+ Ok(OAuthProvider::new(
self.client_email,
self.private_key,
scope.to_string(),
audience.to_string(),
- )?) as Box<dyn TokenProvider>)
+ )?)
}
}
@@ -329,23 +338,14 @@ fn b64_encode_obj<T: serde::Serialize>(obj: &T) -> Result<String> {
#[derive(Debug, Default)]
pub struct InstanceCredentialProvider {
audience: String,
- client: Client,
}
impl InstanceCredentialProvider {
/// Create a new [`InstanceCredentialProvider`], we need to control the client in order to enable http access so save the options.
+ // NOTE(review): the doc line above is stale; the client is no longer stored here.
+ // It is supplied to `fetch_token`, and this constructor only records `audience`.
- pub fn new<T: Into<String>>(
- audience: T,
- client_options: ClientOptions,
- ) -> Result<Self> {
- client_options
- .with_allow_http(true)
- .client()
- .map(|client| Self {
- audience: audience.into(),
- client,
- })
- .context(ClientSnafu)
+ pub fn new<T: Into<String>>(audience: T) -> Self {
+ Self {
+ audience: audience.into(),
+ }
}
}
@@ -355,7 +355,7 @@ async fn make_metadata_request(
hostname: &str,
retry: &RetryConfig,
audience: &str,
-) -> Result<TokenResponse> {
+) -> crate::Result<TokenResponse> {
let url = format!(
"http://{hostname}/computeMetadata/v1/instance/service-accounts/default/token"
);
@@ -374,30 +374,29 @@ async fn make_metadata_request(
#[async_trait]
impl TokenProvider for InstanceCredentialProvider {
+ // The metadata server's access token is handed out as a bearer credential
+ type Credential = GcpCredential;
+
/// Fetch a token from the metadata server.
/// Since the connection is local we need to enable http access and don't actually use the client object passed in.
+ // NOTE(review): the doc line above is stale; the passed-in `client` is now used, so it must
+ // permit plain-HTTP requests (the GCS builder enables `with_allow_http(true)` for this provider).
async fn fetch_token(
&self,
- _client: &Client,
+ client: &Client,
retry: &RetryConfig,
- ) -> Result<TemporaryToken<String>> {
+ ) -> crate::Result<TemporaryToken<Arc<GcpCredential>>> {
const METADATA_IP: &str = "169.254.169.254";
const METADATA_HOST: &str = "metadata";
info!("fetching token from metadata server");
+ // Try the `metadata` hostname first; on any error retry against `METADATA_IP`
let response =
- make_metadata_request(&self.client, METADATA_HOST, retry, &self.audience)
+ make_metadata_request(client, METADATA_HOST, retry, &self.audience)
.or_else(|_| {
- make_metadata_request(
- &self.client,
- METADATA_IP,
- retry,
- &self.audience,
- )
+ make_metadata_request(client, METADATA_IP, retry, &self.audience)
})
.await?;
let token = TemporaryToken {
- token: response.access_token,
+ token: Arc::new(GcpCredential {
+ bearer: response.access_token,
+ }),
expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
};
Ok(token)
@@ -406,31 +405,35 @@ impl TokenProvider for InstanceCredentialProvider {
/// ApplicationDefaultCredentials
/// <https://google.aip.dev/auth/4110>
+///
+/// Reads the credentials file via `ApplicationDefaultCredentialsFile::read(path)` and returns
+/// `Ok(None)` when that yields no file. Only `authorized_user` credentials are supported; they
+/// are wrapped in a [`TokenCredentialProvider`] built from `client` and `retry`. Any other type
+/// results in an `UnsupportedCredentialsType` error.
-#[derive(Debug)]
-pub enum ApplicationDefaultCredentials {
- /// <https://google.aip.dev/auth/4113>
- AuthorizedUser {
- client_id: String,
- client_secret: String,
- refresh_token: String,
- },
-}
-
-impl ApplicationDefaultCredentials {
- pub fn new(path: Option<&str>) -> Result<Option<Self>, Error> {
- let file = match ApplicationDefaultCredentialsFile::read(path)? {
- Some(f) => f,
- None => return Ok(None),
- };
-
- Ok(Some(match file.type_.as_str() {
- "authorized_user" => Self::AuthorizedUser {
+pub fn application_default_credentials(
+ path: Option<&str>,
+ client: &ClientOptions,
+ retry: &RetryConfig,
+) -> crate::Result<Option<GcpCredentialProvider>> {
+ let file = match ApplicationDefaultCredentialsFile::read(path)? {
+ Some(x) => x,
+ None => return Ok(None),
+ };
+
+ match file.type_.as_str() {
+ // <https://google.aip.dev/auth/4113>
+ "authorized_user" => {
+ let token = AuthorizedUserCredentials {
client_id: file.client_id,
client_secret: file.client_secret,
refresh_token: file.refresh_token,
- },
- type_ => return UnsupportedCredentialsTypeSnafu { type_ }.fail(),
- }))
+ };
+
+ Ok(Some(Arc::new(TokenCredentialProvider::new(
+ token,
+ client.client()?,
+ retry.clone(),
+ ))))
+ }
+ type_ => Err(UnsupportedCredentialsType {
+ type_: type_.to_string(),
+ }
+ .into()),
}
}
@@ -473,41 +476,43 @@ impl ApplicationDefaultCredentialsFile {
const DEFAULT_TOKEN_GCP_URI: &str = "https://accounts.google.com/o/oauth2/token";
+/// Credentials of type `authorized_user`, exchanged for access tokens using the
+/// `refresh_token` grant at `DEFAULT_TOKEN_GCP_URI`
+///
+/// <https://google.aip.dev/auth/4113>
+#[derive(Debug)]
+struct AuthorizedUserCredentials {
+ client_id: String,
+ client_secret: String,
+ refresh_token: String,
+}
+
#[async_trait]
-impl TokenProvider for ApplicationDefaultCredentials {
+impl TokenProvider for AuthorizedUserCredentials {
+ type Credential = GcpCredential;
+
+ /// POSTs the refresh-token grant and wraps the returned access token as a bearer credential
async fn fetch_token(
&self,
client: &Client,
retry: &RetryConfig,
- ) -> Result<TemporaryToken<String>, Error> {
- let builder = client.request(Method::POST, DEFAULT_TOKEN_GCP_URI);
- let builder = match self {
- Self::AuthorizedUser {
- client_id,
- client_secret,
- refresh_token,
- } => {
- let body = [
- ("grant_type", "refresh_token"),
- ("client_id", client_id),
- ("client_secret", client_secret),
- ("refresh_token", refresh_token),
- ];
- builder.form(&body)
- }
- };
-
- let response = builder
+ ) -> crate::Result<TemporaryToken<Arc<GcpCredential>>> {
+ let response = client
+ .request(Method::POST, DEFAULT_TOKEN_GCP_URI)
+ .form(&[
+ ("grant_type", "refresh_token"),
+ ("client_id", &self.client_id),
+ ("client_secret", &self.client_secret),
+ ("refresh_token", &self.refresh_token),
+ ])
.send_retry(retry)
.await
.context(TokenRequestSnafu)?
.json::<TokenResponse>()
.await
.context(TokenResponseBodySnafu)?;
- let token = TemporaryToken {
- token: response.access_token,
+
+ Ok(TemporaryToken {
+ token: Arc::new(GcpCredential {
+ bearer: response.access_token,
+ }),
expiry: Some(Instant::now() + Duration::from_secs(response.expires_in)),
- };
- Ok(token)
+ })
}
}
diff --git a/object_store/src/gcp/mod.rs b/object_store/src/gcp/mod.rs
index 32f4055f1..6813bbf6e 100644
--- a/object_store/src/gcp/mod.rs
+++ b/object_store/src/gcp/mod.rs
@@ -48,9 +48,12 @@ use crate::client::header::header_meta;
use crate::client::list::ListResponse;
use crate::client::pagination::stream_paginated;
use crate::client::retry::RetryExt;
-use crate::client::{ClientConfigKey, GetOptionsExt};
+use crate::client::{
+ ClientConfigKey, CredentialProvider, GetOptionsExt, StaticCredentialProvider,
+ TokenCredentialProvider,
+};
+use crate::gcp::credential::{application_default_credentials, GcpCredential};
use crate::{
- client::token::TokenCache,
multipart::{CloudMultiPartUpload, CloudMultiPartUploadImpl, UploadPart},
path::{Path, DELIMITER},
util::format_prefix,
@@ -59,14 +62,15 @@ use crate::{
};
use self::credential::{
- default_gcs_base_url, ApplicationDefaultCredentials, InstanceCredentialProvider,
- ServiceAccountCredentials, TokenProvider,
+ default_gcs_base_url, InstanceCredentialProvider, ServiceAccountCredentials,
};
mod credential;
const STORE: &str = "GCS";
+/// Shared, type-erased source of [`GcpCredential`]s used to sign GCS requests
+type GcpCredentialProvider = Arc<dyn CredentialProvider<Credential = GcpCredential>>;
+
#[derive(Debug, Snafu)]
enum Error {
#[snafu(display("Got invalid XML response for {} {}: {}", method, url, source))]
@@ -119,9 +123,6 @@ enum Error {
#[snafu(display("Missing bucket name"))]
MissingBucketName {},
- #[snafu(display("Could not find either metadata credentials or configuration properties to initialize GCS credentials."))]
- MissingCredentials,
-
#[snafu(display(
"One of service account path or service account key may be provided."
))]
@@ -209,8 +210,7 @@ struct GoogleCloudStorageClient {
client: Client,
base_url: String,
- token_provider: Option<Arc<Box<dyn TokenProvider>>>,
- token_cache: TokenCache<String>,
+ /// Supplies the bearer token attached to each request
+ credentials: GcpCredentialProvider,
bucket_name: String,
bucket_name_encoded: String,
@@ -223,18 +223,8 @@ struct GoogleCloudStorageClient {
}
impl GoogleCloudStorageClient {
- async fn get_token(&self) -> Result<String> {
- if let Some(token_provider) = &self.token_provider {
- Ok(self
- .token_cache
- .get_or_insert_with(|| {
- token_provider.fetch_token(&self.client, &self.retry_config)
- })
- .await
- .context(CredentialSnafu)?)
- } else {
- Ok("".to_owned())
- }
+ /// Returns the credential to sign a request with, as supplied by `credentials`
+ async fn get_credential(&self) -> Result<Arc<GcpCredential>> {
+ self.credentials.get_credential().await
}
fn object_url(&self, path: &Path) -> String {
@@ -249,7 +239,7 @@ impl GoogleCloudStorageClient {
options: GetOptions,
head: bool,
) -> Result<Response> {
- let token = self.get_token().await?;
+ let credential = self.get_credential().await?;
let url = self.object_url(path);
let method = match head {
@@ -260,7 +250,7 @@ impl GoogleCloudStorageClient {
let response = self
.client
.request(method, url)
- .bearer_auth(token)
+ .bearer_auth(&credential.bearer)
.with_get_options(options)
.send_retry(&self.retry_config)
.await
@@ -273,7 +263,7 @@ impl GoogleCloudStorageClient {
/// Perform a put request <https://cloud.google.com/storage/docs/xml-api/put-object-upload>
async fn put_request(&self, path: &Path, payload: Bytes) -> Result<()> {
- let token = self.get_token().await?;
+ let credential = self.get_credential().await?;
let url = self.object_url(path);
let content_type = self
@@ -283,7 +273,7 @@ impl GoogleCloudStorageClient {
self.client
.request(Method::PUT, url)
- .bearer_auth(token)
+ .bearer_auth(&credential.bearer)
.header(header::CONTENT_TYPE, content_type)
.header(header::CONTENT_LENGTH, payload.len())
.body(payload)
@@ -298,7 +288,7 @@ impl GoogleCloudStorageClient {
/// Initiate a multi-part upload <https://cloud.google.com/storage/docs/xml-api/post-object-multipart>
async fn multipart_initiate(&self, path: &Path) -> Result<MultipartId> {
- let token = self.get_token().await?;
+ let credential = self.get_credential().await?;
let url = format!("{}/{}/{}", self.base_url, self.bucket_name_encoded, path);
let content_type = self
@@ -309,7 +299,7 @@ impl GoogleCloudStorageClient {
let response = self
.client
.request(Method::POST, &url)
- .bearer_auth(token)
+ .bearer_auth(&credential.bearer)
.header(header::CONTENT_TYPE, content_type)
.header(header::CONTENT_LENGTH, "0")
.query(&[("uploads", "")])
@@ -338,12 +328,12 @@ impl GoogleCloudStorageClient {
path: &str,
multipart_id: &MultipartId,
) -> Result<()> {
- let token = self.get_token().await?;
+ let credential = self.get_credential().await?;
let url = format!("{}/{}/{}", self.base_url, self.bucket_name_encoded, path);
self.client
.request(Method::DELETE, &url)
- .bearer_auth(token)
+ .bearer_auth(&credential.bearer)
.header(header::CONTENT_TYPE, "application/octet-stream")
.header(header::CONTENT_LENGTH, "0")
.query(&[("uploadId", multipart_id)])
@@ -356,12 +346,12 @@ impl GoogleCloudStorageClient {
/// Perform a delete request <https://cloud.google.com/storage/docs/xml-api/delete-object>
async fn delete_request(&self, path: &Path) -> Result<()> {
- let token = self.get_token().await?;
+ let credential = self.get_credential().await?;
let url = self.object_url(path);
let builder = self.client.request(Method::DELETE, url);
builder
- .bearer_auth(token)
+ .bearer_auth(&credential.bearer)
.send_retry(&self.retry_config)
.await
.context(DeleteRequestSnafu {
@@ -378,7 +368,7 @@ impl GoogleCloudStorageClient {
to: &Path,
if_not_exists: bool,
) -> Result<()> {
- let token = self.get_token().await?;
+ let credential = self.get_credential().await?;
let url = self.object_url(to);
let from = utf8_percent_encode(from.as_ref(), NON_ALPHANUMERIC);
@@ -394,7 +384,7 @@ impl GoogleCloudStorageClient {
}
builder
- .bearer_auth(token)
+ .bearer_auth(&credential.bearer)
// Needed if reqwest is compiled with native-tls instead of rustls-tls
// See https://github.com/apache/arrow-rs/pull/3921
.header(header::CONTENT_LENGTH, 0)
@@ -418,7 +408,7 @@ impl GoogleCloudStorageClient {
delimiter: bool,
page_token: Option<&str>,
) -> Result<ListResponse> {
- let token = self.get_token().await?;
+ let credential = self.get_credential().await?;
let url = format!("{}/{}", self.base_url, self.bucket_name_encoded);
let mut query = Vec::with_capacity(5);
@@ -443,7 +433,7 @@ impl GoogleCloudStorageClient {
.client
.request(Method::GET, url)
.query(&query)
- .bearer_auth(token)
+ .bearer_auth(&credential.bearer)
.send_retry(&self.retry_config)
.await
.context(ListRequestSnafu)?
@@ -495,9 +485,9 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload {
self.client.base_url, self.client.bucket_name_encoded, self.encoded_path
);
- let token = self
+ let credential = self
.client
- .get_token()
+ .get_credential()
.await
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
@@ -505,7 +495,7 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload {
.client
.client
.request(Method::PUT, &url)
- .bearer_auth(token)
+ .bearer_auth(&credential.bearer)
.query(&[
("partNumber", format!("{}", part_idx + 1)),
("uploadId", upload_id),
@@ -549,9 +539,9 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload {
})
.collect();
- let token = self
+ let credential = self
.client
- .get_token()
+ .get_credential()
.await
.map_err(|err| io::Error::new(io::ErrorKind::Other, err))?;
@@ -567,7 +557,7 @@ impl CloudMultiPartUploadImpl for GCSMultipartUpload {
self.client
.client
.request(Method::POST, &url)
- .bearer_auth(token)
+ .bearer_auth(&credential.bearer)
.query(&[("uploadId", upload_id)])
.body(data)
.send_retry(&self.client.retry_config)
@@ -1062,10 +1052,11 @@ impl GoogleCloudStorageBuilder {
};
// Then try to initialize from the application credentials file, or the environment.
- let application_default_credentials = ApplicationDefaultCredentials::new(
+ let application_default_credentials = application_default_credentials(
self.application_credentials_path.as_deref(),
- )
- .context(CredentialSnafu)?;
+ &self.client_options,
+ &self.retry_config,
+ )?;
let disable_oauth = service_account_credentials
.as_ref()
@@ -1081,29 +1072,24 @@ impl GoogleCloudStorageBuilder {
let scope = "https://www.googleapis.com/auth/devstorage.full_control";
let audience = "https://www.googleapis.com/oauth2/v4/token";
- let token_provider = if disable_oauth {
- None
+ // Precedence: disabled OAuth (empty bearer token), then service account credentials,
+ // then application default credentials, then the instance metadata server
+ let credentials = if disable_oauth {
+ Arc::new(StaticCredentialProvider::new(GcpCredential {
+ bearer: "".to_string(),
+ })) as _
+ } else if let Some(credentials) = service_account_credentials {
+ Arc::new(TokenCredentialProvider::new(
+ credentials.oauth_provider(scope, audience)?,
+ self.client_options.client()?,
+ self.retry_config.clone(),
+ )) as _
+ } else if let Some(credentials) = application_default_credentials {
+ credentials
} else {
- let best_provider = if let Some(credentials) = service_account_credentials {
- Some(
- credentials
- .token_provider(scope, audience)
- .context(CredentialSnafu)?,
- )
- } else if let Some(credentials) = application_default_credentials {
- Some(Box::new(credentials) as Box<dyn TokenProvider>)
- } else {
- Some(Box::new(
- InstanceCredentialProvider::new(
- audience,
- self.client_options.clone(),
- )
- .context(CredentialSnafu)?,
- ) as Box<dyn TokenProvider>)
- };
-
- // A provider is required at this point, bail out if we don't have one.
- Some(best_provider.ok_or(Error::MissingCredentials)?)
+ // Instance metadata fallback; only this client allows plain HTTP, as the
+ // metadata endpoint is an `http://` URL
+ Arc::new(TokenCredentialProvider::new(
+ InstanceCredentialProvider::new(audience),
+ self.client_options.clone().with_allow_http(true).client()?,
+ self.retry_config.clone(),
+ )) as _
};
let encoded_bucket_name =
@@ -1113,8 +1099,7 @@ impl GoogleCloudStorageBuilder {
client: Arc::new(GoogleCloudStorageClient {
client,
base_url: gcs_base_url,
- token_provider: token_provider.map(Arc::new),
- token_cache: Default::default(),
+ credentials,
bucket_name,
bucket_name_encoded: encoded_bucket_name,
retry_config: self.retry_config,