You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/08/30 16:42:30 UTC
[arrow-rs] branch master updated: Add IMDSv1 fallback (#2609) (#2610)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 62eeaa5eb Add IMDSv1 fallback (#2609) (#2610)
62eeaa5eb is described below
commit 62eeaa5ebd59ac611b8d17f2fc26373fc30af53f
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Tue Aug 30 17:42:26 2022 +0100
Add IMDSv1 fallback (#2609) (#2610)
* Add IMDSv1 fallback (#2609)
* Add config option
---
object_store/src/aws/credential.rs | 165 +++++++++++++++++++++++++++------
object_store/src/aws/mod.rs | 19 ++++
object_store/src/client/mock_server.rs | 105 +++++++++++++++++++++
object_store/src/client/mod.rs | 2 +
object_store/src/client/retry.rs | 60 +++---------
5 files changed, 276 insertions(+), 75 deletions(-)
diff --git a/object_store/src/aws/credential.rs b/object_store/src/aws/credential.rs
index e6c1bdd74..1abf42be9 100644
--- a/object_store/src/aws/credential.rs
+++ b/object_store/src/aws/credential.rs
@@ -23,11 +23,12 @@ use bytes::Buf;
use chrono::{DateTime, Utc};
use futures::TryFutureExt;
use reqwest::header::{HeaderMap, HeaderValue};
-use reqwest::{Client, Method, Request, RequestBuilder};
+use reqwest::{Client, Method, Request, RequestBuilder, StatusCode};
use serde::Deserialize;
use std::collections::BTreeMap;
use std::sync::Arc;
use std::time::Instant;
+use tracing::warn;
type StdError = Box<dyn std::error::Error + Send + Sync>;
@@ -284,6 +285,7 @@ pub struct InstanceCredentialProvider {
pub cache: TokenCache<Arc<AwsCredential>>,
pub client: Client,
pub retry_config: RetryConfig,
+ pub imdsv1_fallback: bool,
}
impl InstanceCredentialProvider {
@@ -291,11 +293,16 @@ impl InstanceCredentialProvider {
self.cache
.get_or_insert_with(|| {
const METADATA_ENDPOINT: &str = "http://169.254.169.254";
- instance_creds(&self.client, &self.retry_config, METADATA_ENDPOINT)
- .map_err(|source| crate::Error::Generic {
- store: "S3",
- source,
- })
+ instance_creds(
+ &self.client,
+ &self.retry_config,
+ METADATA_ENDPOINT,
+ self.imdsv1_fallback,
+ )
+ .map_err(|source| crate::Error::Generic {
+ store: "S3",
+ source,
+ })
})
.await
}
@@ -360,36 +367,47 @@ async fn instance_creds(
client: &Client,
retry_config: &RetryConfig,
endpoint: &str,
+ imdsv1_fallback: bool,
) -> Result<TemporaryToken<Arc<AwsCredential>>, StdError> {
const CREDENTIALS_PATH: &str = "latest/meta-data/iam/security-credentials";
const AWS_EC2_METADATA_TOKEN_HEADER: &str = "X-aws-ec2-metadata-token";
let token_url = format!("{}/latest/api/token", endpoint);
- let token = client
+
+ let token_result = client
.request(Method::PUT, token_url)
.header("X-aws-ec2-metadata-token-ttl-seconds", "600") // 10 minute TTL
.send_retry(retry_config)
- .await?
- .text()
- .await?;
+ .await;
+
+ let token = match token_result {
+ Ok(t) => Some(t.text().await?),
+ Err(e)
+ if imdsv1_fallback && matches!(e.status(), Some(StatusCode::FORBIDDEN)) =>
+ {
+ warn!("received 403 from metadata endpoint, falling back to IMDSv1");
+ None
+ }
+ Err(e) => return Err(e.into()),
+ };
let role_url = format!("{}/{}/", endpoint, CREDENTIALS_PATH);
- let role = client
- .request(Method::GET, role_url)
- .header(AWS_EC2_METADATA_TOKEN_HEADER, &token)
- .send_retry(retry_config)
- .await?
- .text()
- .await?;
+ let mut role_request = client.request(Method::GET, role_url);
+
+ if let Some(token) = &token {
+ role_request = role_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token);
+ }
+
+ let role = role_request.send_retry(retry_config).await?.text().await?;
let creds_url = format!("{}/{}/{}", endpoint, CREDENTIALS_PATH, role);
- let creds: InstanceCredentials = client
- .request(Method::GET, creds_url)
- .header(AWS_EC2_METADATA_TOKEN_HEADER, &token)
- .send_retry(retry_config)
- .await?
- .json()
- .await?;
+ let mut creds_request = client.request(Method::GET, creds_url);
+ if let Some(token) = &token {
+ creds_request = creds_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token);
+ }
+
+ let creds: InstanceCredentials =
+ creds_request.send_retry(retry_config).await?.json().await?;
let now = Utc::now();
let ttl = (creds.expiration - now).to_std().unwrap_or_default();
@@ -470,6 +488,8 @@ async fn web_identity(
#[cfg(test)]
mod tests {
use super::*;
+ use crate::client::mock_server::MockServer;
+ use hyper::{Body, Response};
use reqwest::{Client, Method};
use std::env;
@@ -567,11 +587,11 @@ mod tests {
assert_eq!(
resp.status(),
- reqwest::StatusCode::UNAUTHORIZED,
+ StatusCode::UNAUTHORIZED,
"Ensure metadata endpoint is set to only allow IMDSv2"
);
- let creds = instance_creds(&client, &retry_config, &endpoint)
+ let creds = instance_creds(&client, &retry_config, &endpoint, false)
.await
.unwrap();
@@ -583,4 +603,97 @@ mod tests {
assert!(!secret.is_empty());
assert!(!token.is_empty())
}
+
+ #[tokio::test]
+ async fn test_mock() {
+ let server = MockServer::new();
+
+ const IMDSV2_HEADER: &str = "X-aws-ec2-metadata-token";
+
+ let secret_access_key = "SECRET";
+ let access_key_id = "KEYID";
+ let token = "TOKEN";
+
+ let endpoint = server.url();
+ let client = Client::new();
+ let retry_config = RetryConfig::default();
+
+ // Test IMDSv2
+ server.push_fn(|req| {
+ assert_eq!(req.uri().path(), "/latest/api/token");
+ assert_eq!(req.method(), &Method::PUT);
+ Response::new(Body::from("cupcakes"))
+ });
+ server.push_fn(|req| {
+ assert_eq!(
+ req.uri().path(),
+ "/latest/meta-data/iam/security-credentials/"
+ );
+ assert_eq!(req.method(), &Method::GET);
+ let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap();
+ assert_eq!(t, "cupcakes");
+ Response::new(Body::from("myrole"))
+ });
+ server.push_fn(|req| {
+ assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole");
+ assert_eq!(req.method(), &Method::GET);
+ let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap();
+ assert_eq!(t, "cupcakes");
+ Response::new(Body::from(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#))
+ });
+
+ let creds = instance_creds(&client, &retry_config, endpoint, true)
+ .await
+ .unwrap();
+
+ assert_eq!(creds.token.token.as_deref().unwrap(), token);
+ assert_eq!(&creds.token.key_id, access_key_id);
+ assert_eq!(&creds.token.secret_key, secret_access_key);
+
+ // Test IMDSv1 fallback
+ server.push_fn(|req| {
+ assert_eq!(req.uri().path(), "/latest/api/token");
+ assert_eq!(req.method(), &Method::PUT);
+ Response::builder()
+ .status(StatusCode::FORBIDDEN)
+ .body(Body::empty())
+ .unwrap()
+ });
+ server.push_fn(|req| {
+ assert_eq!(
+ req.uri().path(),
+ "/latest/meta-data/iam/security-credentials/"
+ );
+ assert_eq!(req.method(), &Method::GET);
+ assert!(req.headers().get(IMDSV2_HEADER).is_none());
+ Response::new(Body::from("myrole"))
+ });
+ server.push_fn(|req| {
+ assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole");
+ assert_eq!(req.method(), &Method::GET);
+ assert!(req.headers().get(IMDSV2_HEADER).is_none());
+ Response::new(Body::from(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#))
+ });
+
+ let creds = instance_creds(&client, &retry_config, endpoint, true)
+ .await
+ .unwrap();
+
+ assert_eq!(creds.token.token.as_deref().unwrap(), token);
+ assert_eq!(&creds.token.key_id, access_key_id);
+ assert_eq!(&creds.token.secret_key, secret_access_key);
+
+ // Test IMDSv1 fallback disabled
+ server.push(
+ Response::builder()
+ .status(StatusCode::FORBIDDEN)
+ .body(Body::empty())
+ .unwrap(),
+ );
+
+ // Should fail
+ instance_creds(&client, &retry_config, endpoint, false)
+ .await
+ .unwrap_err();
+ }
}
diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs
index ab90afa5d..d1d0a12cd 100644
--- a/object_store/src/aws/mod.rs
+++ b/object_store/src/aws/mod.rs
@@ -339,6 +339,7 @@ pub struct AmazonS3Builder {
token: Option<String>,
retry_config: RetryConfig,
allow_http: bool,
+ imdsv1_fallback: bool,
}
impl AmazonS3Builder {
@@ -446,6 +447,23 @@ impl AmazonS3Builder {
self
}
+ /// By default, instance credentials will only be fetched over [IMDSv2], as AWS recommends
+ /// against having IMDSv1 enabled on EC2 instances as it is vulnerable to [SSRF attack].
+ ///
+ /// However, certain deployment environments, such as those running old versions of kube2iam,
+ /// may not support IMDSv2. This option will enable automatic fallback to using IMDSv1
+ /// if the token endpoint returns a 403 error indicating that IMDSv2 is not supported.
+ ///
+ /// This option has no effect if not using instance credentials
+ ///
+ /// [IMDSv2]: https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html
+ /// [SSRF attack]: https://aws.amazon.com/blogs/security/defense-in-depth-open-firewalls-reverse-proxies-ssrf-vulnerabilities-ec2-instance-metadata-service/
+ ///
+ pub fn with_imdsv1_fallback(mut self) -> Self {
+ self.imdsv1_fallback = true;
+ self
+ }
+
/// Create a [`AmazonS3`] instance from the provided values,
/// consuming `self`.
pub fn build(self) -> Result<AmazonS3> {
@@ -503,6 +521,7 @@ impl AmazonS3Builder {
cache: Default::default(),
client,
retry_config: self.retry_config.clone(),
+ imdsv1_fallback: self.imdsv1_fallback,
})
}
},
diff --git a/object_store/src/client/mock_server.rs b/object_store/src/client/mock_server.rs
new file mode 100644
index 000000000..adb7e0fff
--- /dev/null
+++ b/object_store/src/client/mock_server.rs
@@ -0,0 +1,105 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use hyper::service::{make_service_fn, service_fn};
+use hyper::{Body, Request, Response, Server};
+use parking_lot::Mutex;
+use std::collections::VecDeque;
+use std::convert::Infallible;
+use std::net::SocketAddr;
+use std::sync::Arc;
+use tokio::sync::oneshot;
+use tokio::task::JoinHandle;
+
+pub type ResponseFn = Box<dyn FnOnce(Request<Body>) -> Response<Body> + Send>;
+
+/// A mock server
+pub struct MockServer {
+ responses: Arc<Mutex<VecDeque<ResponseFn>>>,
+ shutdown: oneshot::Sender<()>,
+ handle: JoinHandle<()>,
+ url: String,
+}
+
+impl MockServer {
+ pub fn new() -> Self {
+ let responses: Arc<Mutex<VecDeque<ResponseFn>>> =
+ Arc::new(Mutex::new(VecDeque::with_capacity(10)));
+
+ let r = Arc::clone(&responses);
+ let make_service = make_service_fn(move |_conn| {
+ let r = Arc::clone(&r);
+ async move {
+ Ok::<_, Infallible>(service_fn(move |req| {
+ let r = Arc::clone(&r);
+ async move {
+ Ok::<_, Infallible>(match r.lock().pop_front() {
+ Some(r) => r(req),
+ None => Response::new(Body::from("Hello World")),
+ })
+ }
+ }))
+ }
+ });
+
+ let (shutdown, rx) = oneshot::channel::<()>();
+ let server =
+ Server::bind(&SocketAddr::from(([127, 0, 0, 1], 0))).serve(make_service);
+
+ let url = format!("http://{}", server.local_addr());
+
+ let handle = tokio::spawn(async move {
+ server
+ .with_graceful_shutdown(async {
+ rx.await.ok();
+ })
+ .await
+ .unwrap()
+ });
+
+ Self {
+ responses,
+ shutdown,
+ handle,
+ url,
+ }
+ }
+
+ /// The url of the mock server
+ pub fn url(&self) -> &str {
+ &self.url
+ }
+
+ /// Add a response
+ pub fn push(&self, response: Response<Body>) {
+ self.push_fn(|_| response)
+ }
+
+ /// Add a response function
+ pub fn push_fn<F>(&self, f: F)
+ where
+ F: FnOnce(Request<Body>) -> Response<Body> + Send + 'static,
+ {
+ self.responses.lock().push_back(Box::new(f))
+ }
+
+ /// Shutdown the mock server
+ pub async fn shutdown(self) {
+ let _ = self.shutdown.send(());
+ self.handle.await.unwrap()
+ }
+}
diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs
index e6de3e929..c93c68a1f 100644
--- a/object_store/src/client/mod.rs
+++ b/object_store/src/client/mod.rs
@@ -18,6 +18,8 @@
//! Generic utilities reqwest based ObjectStore implementations
pub mod backoff;
+#[cfg(test)]
+pub mod mock_server;
pub mod pagination;
pub mod retry;
pub mod token;
diff --git a/object_store/src/client/retry.rs b/object_store/src/client/retry.rs
index 44d7835a5..d66628aec 100644
--- a/object_store/src/client/retry.rs
+++ b/object_store/src/client/retry.rs
@@ -180,54 +180,17 @@ impl RetryExt for reqwest::RequestBuilder {
#[cfg(test)]
mod tests {
+ use crate::client::mock_server::MockServer;
use crate::client::retry::RetryExt;
use crate::RetryConfig;
use hyper::header::LOCATION;
- use hyper::service::{make_service_fn, service_fn};
- use hyper::{Body, Response, Server};
- use parking_lot::Mutex;
+ use hyper::{Body, Response};
use reqwest::{Client, Method, StatusCode};
- use std::collections::VecDeque;
- use std::convert::Infallible;
- use std::net::SocketAddr;
- use std::sync::Arc;
use std::time::Duration;
#[tokio::test]
async fn test_retry() {
- let responses: Arc<Mutex<VecDeque<Response<Body>>>> =
- Arc::new(Mutex::new(VecDeque::with_capacity(10)));
-
- let r = Arc::clone(&responses);
- let make_service = make_service_fn(move |_conn| {
- let r = Arc::clone(&r);
- async move {
- Ok::<_, Infallible>(service_fn(move |_req| {
- let r = Arc::clone(&r);
- async move {
- Ok::<_, Infallible>(match r.lock().pop_front() {
- Some(r) => r,
- None => Response::new(Body::from("Hello World")),
- })
- }
- }))
- }
- });
-
- let (tx, rx) = tokio::sync::oneshot::channel::<()>();
- let server =
- Server::bind(&SocketAddr::from(([127, 0, 0, 1], 0))).serve(make_service);
-
- let url = format!("http://{}", server.local_addr());
-
- let server_handle = tokio::spawn(async move {
- server
- .with_graceful_shutdown(async {
- rx.await.ok();
- })
- .await
- .unwrap()
- });
+ let mock = MockServer::new();
let retry = RetryConfig {
backoff: Default::default(),
@@ -236,14 +199,14 @@ mod tests {
};
let client = Client::new();
- let do_request = || client.request(Method::GET, &url).send_retry(&retry);
+ let do_request = || client.request(Method::GET, mock.url()).send_retry(&retry);
// Simple request should work
let r = do_request().await.unwrap();
assert_eq!(r.status(), StatusCode::OK);
// Returns client errors immediately with status message
- responses.lock().push_back(
+ mock.push(
Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::from("cupcakes"))
@@ -256,7 +219,7 @@ mod tests {
assert_eq!(&e.message, "cupcakes");
// Handles client errors with no payload
- responses.lock().push_back(
+ mock.push(
Response::builder()
.status(StatusCode::BAD_REQUEST)
.body(Body::empty())
@@ -269,7 +232,7 @@ mod tests {
assert_eq!(&e.message, "No Body");
// Should retry server error request
- responses.lock().push_back(
+ mock.push(
Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::empty())
@@ -280,7 +243,7 @@ mod tests {
assert_eq!(r.status(), StatusCode::OK);
// Accepts 204 status code
- responses.lock().push_back(
+ mock.push(
Response::builder()
.status(StatusCode::NO_CONTENT)
.body(Body::empty())
@@ -291,7 +254,7 @@ mod tests {
assert_eq!(r.status(), StatusCode::NO_CONTENT);
// Follows redirects
- responses.lock().push_back(
+ mock.push(
Response::builder()
.status(StatusCode::FOUND)
.header(LOCATION, "/foo")
@@ -305,7 +268,7 @@ mod tests {
// Gives up after the retrying the specified number of times
for _ in 0..=retry.max_retries {
- responses.lock().push_back(
+ mock.push(
Response::builder()
.status(StatusCode::BAD_GATEWAY)
.body(Body::from("ignored"))
@@ -318,7 +281,6 @@ mod tests {
assert_eq!(e.message, "502 Bad Gateway");
// Shutdown
- let _ = tx.send(());
- server_handle.await.unwrap();
+ mock.shutdown().await
}
}