You are viewing a plain-text version of this content. The canonical link for it was not preserved in this plain-text rendering.
Posted to commits@arrow.apache.org by tu...@apache.org on 2022/08/30 16:42:30 UTC

[arrow-rs] branch master updated: Add IMDSv1 fallback (#2609) (#2610)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new 62eeaa5eb Add IMDSv1 fallback (#2609) (#2610)
62eeaa5eb is described below

commit 62eeaa5ebd59ac611b8d17f2fc26373fc30af53f
Author: Raphael Taylor-Davies <17...@users.noreply.github.com>
AuthorDate: Tue Aug 30 17:42:26 2022 +0100

    Add IMDSv1 fallback (#2609) (#2610)
    
    * Add IMDSv1 fallback (#2609)
    
    * Add config option
---
 object_store/src/aws/credential.rs     | 165 +++++++++++++++++++++++++++------
 object_store/src/aws/mod.rs            |  19 ++++
 object_store/src/client/mock_server.rs | 105 +++++++++++++++++++++
 object_store/src/client/mod.rs         |   2 +
 object_store/src/client/retry.rs       |  60 +++---------
 5 files changed, 276 insertions(+), 75 deletions(-)

diff --git a/object_store/src/aws/credential.rs b/object_store/src/aws/credential.rs
index e6c1bdd74..1abf42be9 100644
--- a/object_store/src/aws/credential.rs
+++ b/object_store/src/aws/credential.rs
@@ -23,11 +23,12 @@ use bytes::Buf;
 use chrono::{DateTime, Utc};
 use futures::TryFutureExt;
 use reqwest::header::{HeaderMap, HeaderValue};
-use reqwest::{Client, Method, Request, RequestBuilder};
+use reqwest::{Client, Method, Request, RequestBuilder, StatusCode};
 use serde::Deserialize;
 use std::collections::BTreeMap;
 use std::sync::Arc;
 use std::time::Instant;
+use tracing::warn;
 
 type StdError = Box<dyn std::error::Error + Send + Sync>;
 
@@ -284,6 +285,7 @@ pub struct InstanceCredentialProvider {
     pub cache: TokenCache<Arc<AwsCredential>>,
     pub client: Client,
     pub retry_config: RetryConfig,
+    pub imdsv1_fallback: bool,
 }
 
 impl InstanceCredentialProvider {
@@ -291,11 +293,16 @@ impl InstanceCredentialProvider {
         self.cache
             .get_or_insert_with(|| {
                 const METADATA_ENDPOINT: &str = "http://169.254.169.254";
-                instance_creds(&self.client, &self.retry_config, METADATA_ENDPOINT)
-                    .map_err(|source| crate::Error::Generic {
-                        store: "S3",
-                        source,
-                    })
+                instance_creds(
+                    &self.client,
+                    &self.retry_config,
+                    METADATA_ENDPOINT,
+                    self.imdsv1_fallback,
+                )
+                .map_err(|source| crate::Error::Generic {
+                    store: "S3",
+                    source,
+                })
             })
             .await
     }
@@ -360,36 +367,47 @@ async fn instance_creds(
     client: &Client,
     retry_config: &RetryConfig,
     endpoint: &str,
+    imdsv1_fallback: bool,
 ) -> Result<TemporaryToken<Arc<AwsCredential>>, StdError> {
     const CREDENTIALS_PATH: &str = "latest/meta-data/iam/security-credentials";
     const AWS_EC2_METADATA_TOKEN_HEADER: &str = "X-aws-ec2-metadata-token";
 
     let token_url = format!("{}/latest/api/token", endpoint);
-    let token = client
+
+    let token_result = client
         .request(Method::PUT, token_url)
         .header("X-aws-ec2-metadata-token-ttl-seconds", "600") // 10 minute TTL
         .send_retry(retry_config)
-        .await?
-        .text()
-        .await?;
+        .await;
+
+    let token = match token_result {
+        Ok(t) => Some(t.text().await?),
+        Err(e)
+            if imdsv1_fallback && matches!(e.status(), Some(StatusCode::FORBIDDEN)) =>
+        {
+            warn!("received 403 from metadata endpoint, falling back to IMDSv1");
+            None
+        }
+        Err(e) => return Err(e.into()),
+    };
 
     let role_url = format!("{}/{}/", endpoint, CREDENTIALS_PATH);
-    let role = client
-        .request(Method::GET, role_url)
-        .header(AWS_EC2_METADATA_TOKEN_HEADER, &token)
-        .send_retry(retry_config)
-        .await?
-        .text()
-        .await?;
+    let mut role_request = client.request(Method::GET, role_url);
+
+    if let Some(token) = &token {
+        role_request = role_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token);
+    }
+
+    let role = role_request.send_retry(retry_config).await?.text().await?;
 
     let creds_url = format!("{}/{}/{}", endpoint, CREDENTIALS_PATH, role);
-    let creds: InstanceCredentials = client
-        .request(Method::GET, creds_url)
-        .header(AWS_EC2_METADATA_TOKEN_HEADER, &token)
-        .send_retry(retry_config)
-        .await?
-        .json()
-        .await?;
+    let mut creds_request = client.request(Method::GET, creds_url);
+    if let Some(token) = &token {
+        creds_request = creds_request.header(AWS_EC2_METADATA_TOKEN_HEADER, token);
+    }
+
+    let creds: InstanceCredentials =
+        creds_request.send_retry(retry_config).await?.json().await?;
 
     let now = Utc::now();
     let ttl = (creds.expiration - now).to_std().unwrap_or_default();
@@ -470,6 +488,8 @@ async fn web_identity(
 #[cfg(test)]
 mod tests {
     use super::*;
+    use crate::client::mock_server::MockServer;
+    use hyper::{Body, Response};
     use reqwest::{Client, Method};
     use std::env;
 
@@ -567,11 +587,11 @@ mod tests {
 
         assert_eq!(
             resp.status(),
-            reqwest::StatusCode::UNAUTHORIZED,
+            StatusCode::UNAUTHORIZED,
             "Ensure metadata endpoint is set to only allow IMDSv2"
         );
 
-        let creds = instance_creds(&client, &retry_config, &endpoint)
+        let creds = instance_creds(&client, &retry_config, &endpoint, false)
             .await
             .unwrap();
 
@@ -583,4 +603,97 @@ mod tests {
         assert!(!secret.is_empty());
         assert!(!token.is_empty())
     }
+
+    #[tokio::test]
+    async fn test_mock() {
+        let server = MockServer::new();
+
+        const IMDSV2_HEADER: &str = "X-aws-ec2-metadata-token";
+
+        let secret_access_key = "SECRET";
+        let access_key_id = "KEYID";
+        let token = "TOKEN";
+
+        let endpoint = server.url();
+        let client = Client::new();
+        let retry_config = RetryConfig::default();
+
+        // Test IMDSv2
+        server.push_fn(|req| {
+            assert_eq!(req.uri().path(), "/latest/api/token");
+            assert_eq!(req.method(), &Method::PUT);
+            Response::new(Body::from("cupcakes"))
+        });
+        server.push_fn(|req| {
+            assert_eq!(
+                req.uri().path(),
+                "/latest/meta-data/iam/security-credentials/"
+            );
+            assert_eq!(req.method(), &Method::GET);
+            let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap();
+            assert_eq!(t, "cupcakes");
+            Response::new(Body::from("myrole"))
+        });
+        server.push_fn(|req| {
+            assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole");
+            assert_eq!(req.method(), &Method::GET);
+            let t = req.headers().get(IMDSV2_HEADER).unwrap().to_str().unwrap();
+            assert_eq!(t, "cupcakes");
+            Response::new(Body::from(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#))
+        });
+
+        let creds = instance_creds(&client, &retry_config, endpoint, true)
+            .await
+            .unwrap();
+
+        assert_eq!(creds.token.token.as_deref().unwrap(), token);
+        assert_eq!(&creds.token.key_id, access_key_id);
+        assert_eq!(&creds.token.secret_key, secret_access_key);
+
+        // Test IMDSv1 fallback
+        server.push_fn(|req| {
+            assert_eq!(req.uri().path(), "/latest/api/token");
+            assert_eq!(req.method(), &Method::PUT);
+            Response::builder()
+                .status(StatusCode::FORBIDDEN)
+                .body(Body::empty())
+                .unwrap()
+        });
+        server.push_fn(|req| {
+            assert_eq!(
+                req.uri().path(),
+                "/latest/meta-data/iam/security-credentials/"
+            );
+            assert_eq!(req.method(), &Method::GET);
+            assert!(req.headers().get(IMDSV2_HEADER).is_none());
+            Response::new(Body::from("myrole"))
+        });
+        server.push_fn(|req| {
+            assert_eq!(req.uri().path(), "/latest/meta-data/iam/security-credentials/myrole");
+            assert_eq!(req.method(), &Method::GET);
+            assert!(req.headers().get(IMDSV2_HEADER).is_none());
+            Response::new(Body::from(r#"{"AccessKeyId":"KEYID","Code":"Success","Expiration":"2022-08-30T10:51:04Z","LastUpdated":"2022-08-30T10:21:04Z","SecretAccessKey":"SECRET","Token":"TOKEN","Type":"AWS-HMAC"}"#))
+        });
+
+        let creds = instance_creds(&client, &retry_config, endpoint, true)
+            .await
+            .unwrap();
+
+        assert_eq!(creds.token.token.as_deref().unwrap(), token);
+        assert_eq!(&creds.token.key_id, access_key_id);
+        assert_eq!(&creds.token.secret_key, secret_access_key);
+
+        // Test IMDSv1 fallback disabled
+        server.push(
+            Response::builder()
+                .status(StatusCode::FORBIDDEN)
+                .body(Body::empty())
+                .unwrap(),
+        );
+
+        // Should fail
+        instance_creds(&client, &retry_config, endpoint, false)
+            .await
+            .unwrap_err();
+    }
 }
diff --git a/object_store/src/aws/mod.rs b/object_store/src/aws/mod.rs
index ab90afa5d..d1d0a12cd 100644
--- a/object_store/src/aws/mod.rs
+++ b/object_store/src/aws/mod.rs
@@ -339,6 +339,7 @@ pub struct AmazonS3Builder {
     token: Option<String>,
     retry_config: RetryConfig,
     allow_http: bool,
+    imdsv1_fallback: bool,
 }
 
 impl AmazonS3Builder {
@@ -446,6 +447,23 @@ impl AmazonS3Builder {
         self
     }
 
+    /// By default instance credentials will only be fetched over [IMDSv2], as AWS recommends
+    /// against having IMDSv1 enabled on EC2 instances as it is vulnerable to [SSRF attack]
+    ///
+    /// However, certain deployment environments, such as those running old versions of kube2iam,
+    /// may not support IMDSv2. This option will enable automatic fallback to using IMDSv1
+    /// if the token endpoint returns a 403 error indicating that IMDSv2 is not supported.
+    ///
+    /// This option has no effect if not using instance credentials
+    ///
+    /// [IMDSv2]: [https://docs.aws.amazon.com/AWSEC2/latest/UserGuide/configuring-instance-metadata-service.html]
+    /// [SSRF attack]: [https://aws.amazon.com/blogs/security/defense-in-depth-open-firewalls-reverse-proxies-ssrf-vulnerabilities-ec2-instance-metadata-service/]
+    ///
+    pub fn with_imdsv1_fallback(mut self) -> Self {
+        self.imdsv1_fallback = true;
+        self
+    }
+
     /// Create a [`AmazonS3`] instance from the provided values,
     /// consuming `self`.
     pub fn build(self) -> Result<AmazonS3> {
@@ -503,6 +521,7 @@ impl AmazonS3Builder {
                         cache: Default::default(),
                         client,
                         retry_config: self.retry_config.clone(),
+                        imdsv1_fallback: self.imdsv1_fallback,
                     })
                 }
             },
diff --git a/object_store/src/client/mock_server.rs b/object_store/src/client/mock_server.rs
new file mode 100644
index 000000000..adb7e0fff
--- /dev/null
+++ b/object_store/src/client/mock_server.rs
@@ -0,0 +1,105 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use hyper::service::{make_service_fn, service_fn};
+use hyper::{Body, Request, Response, Server};
+use parking_lot::Mutex;
+use std::collections::VecDeque;
+use std::convert::Infallible;
+use std::net::SocketAddr;
+use std::sync::Arc;
+use tokio::sync::oneshot;
+use tokio::task::JoinHandle;
+
+pub type ResponseFn = Box<dyn FnOnce(Request<Body>) -> Response<Body> + Send>;
+
+/// A mock server
+pub struct MockServer {
+    responses: Arc<Mutex<VecDeque<ResponseFn>>>,
+    shutdown: oneshot::Sender<()>,
+    handle: JoinHandle<()>,
+    url: String,
+}
+
+impl MockServer {
+    pub fn new() -> Self {
+        let responses: Arc<Mutex<VecDeque<ResponseFn>>> =
+            Arc::new(Mutex::new(VecDeque::with_capacity(10)));
+
+        let r = Arc::clone(&responses);
+        let make_service = make_service_fn(move |_conn| {
+            let r = Arc::clone(&r);
+            async move {
+                Ok::<_, Infallible>(service_fn(move |req| {
+                    let r = Arc::clone(&r);
+                    async move {
+                        Ok::<_, Infallible>(match r.lock().pop_front() {
+                            Some(r) => r(req),
+                            None => Response::new(Body::from("Hello World")),
+                        })
+                    }
+                }))
+            }
+        });
+
+        let (shutdown, rx) = oneshot::channel::<()>();
+        let server =
+            Server::bind(&SocketAddr::from(([127, 0, 0, 1], 0))).serve(make_service);
+
+        let url = format!("http://{}", server.local_addr());
+
+        let handle = tokio::spawn(async move {
+            server
+                .with_graceful_shutdown(async {
+                    rx.await.ok();
+                })
+                .await
+                .unwrap()
+        });
+
+        Self {
+            responses,
+            shutdown,
+            handle,
+            url,
+        }
+    }
+
+    /// The url of the mock server
+    pub fn url(&self) -> &str {
+        &self.url
+    }
+
+    /// Add a response
+    pub fn push(&self, response: Response<Body>) {
+        self.push_fn(|_| response)
+    }
+
+    /// Add a response function
+    pub fn push_fn<F>(&self, f: F)
+    where
+        F: FnOnce(Request<Body>) -> Response<Body> + Send + 'static,
+    {
+        self.responses.lock().push_back(Box::new(f))
+    }
+
+    /// Shutdown the mock server
+    pub async fn shutdown(self) {
+        let _ = self.shutdown.send(());
+        self.handle.await.unwrap()
+    }
+}
diff --git a/object_store/src/client/mod.rs b/object_store/src/client/mod.rs
index e6de3e929..c93c68a1f 100644
--- a/object_store/src/client/mod.rs
+++ b/object_store/src/client/mod.rs
@@ -18,6 +18,8 @@
 //! Generic utilities reqwest based ObjectStore implementations
 
 pub mod backoff;
+#[cfg(test)]
+pub mod mock_server;
 pub mod pagination;
 pub mod retry;
 pub mod token;
diff --git a/object_store/src/client/retry.rs b/object_store/src/client/retry.rs
index 44d7835a5..d66628aec 100644
--- a/object_store/src/client/retry.rs
+++ b/object_store/src/client/retry.rs
@@ -180,54 +180,17 @@ impl RetryExt for reqwest::RequestBuilder {
 
 #[cfg(test)]
 mod tests {
+    use crate::client::mock_server::MockServer;
     use crate::client::retry::RetryExt;
     use crate::RetryConfig;
     use hyper::header::LOCATION;
-    use hyper::service::{make_service_fn, service_fn};
-    use hyper::{Body, Response, Server};
-    use parking_lot::Mutex;
+    use hyper::{Body, Response};
     use reqwest::{Client, Method, StatusCode};
-    use std::collections::VecDeque;
-    use std::convert::Infallible;
-    use std::net::SocketAddr;
-    use std::sync::Arc;
     use std::time::Duration;
 
     #[tokio::test]
     async fn test_retry() {
-        let responses: Arc<Mutex<VecDeque<Response<Body>>>> =
-            Arc::new(Mutex::new(VecDeque::with_capacity(10)));
-
-        let r = Arc::clone(&responses);
-        let make_service = make_service_fn(move |_conn| {
-            let r = Arc::clone(&r);
-            async move {
-                Ok::<_, Infallible>(service_fn(move |_req| {
-                    let r = Arc::clone(&r);
-                    async move {
-                        Ok::<_, Infallible>(match r.lock().pop_front() {
-                            Some(r) => r,
-                            None => Response::new(Body::from("Hello World")),
-                        })
-                    }
-                }))
-            }
-        });
-
-        let (tx, rx) = tokio::sync::oneshot::channel::<()>();
-        let server =
-            Server::bind(&SocketAddr::from(([127, 0, 0, 1], 0))).serve(make_service);
-
-        let url = format!("http://{}", server.local_addr());
-
-        let server_handle = tokio::spawn(async move {
-            server
-                .with_graceful_shutdown(async {
-                    rx.await.ok();
-                })
-                .await
-                .unwrap()
-        });
+        let mock = MockServer::new();
 
         let retry = RetryConfig {
             backoff: Default::default(),
@@ -236,14 +199,14 @@ mod tests {
         };
 
         let client = Client::new();
-        let do_request = || client.request(Method::GET, &url).send_retry(&retry);
+        let do_request = || client.request(Method::GET, mock.url()).send_retry(&retry);
 
         // Simple request should work
         let r = do_request().await.unwrap();
         assert_eq!(r.status(), StatusCode::OK);
 
         // Returns client errors immediately with status message
-        responses.lock().push_back(
+        mock.push(
             Response::builder()
                 .status(StatusCode::BAD_REQUEST)
                 .body(Body::from("cupcakes"))
@@ -256,7 +219,7 @@ mod tests {
         assert_eq!(&e.message, "cupcakes");
 
         // Handles client errors with no payload
-        responses.lock().push_back(
+        mock.push(
             Response::builder()
                 .status(StatusCode::BAD_REQUEST)
                 .body(Body::empty())
@@ -269,7 +232,7 @@ mod tests {
         assert_eq!(&e.message, "No Body");
 
         // Should retry server error request
-        responses.lock().push_back(
+        mock.push(
             Response::builder()
                 .status(StatusCode::BAD_GATEWAY)
                 .body(Body::empty())
@@ -280,7 +243,7 @@ mod tests {
         assert_eq!(r.status(), StatusCode::OK);
 
         // Accepts 204 status code
-        responses.lock().push_back(
+        mock.push(
             Response::builder()
                 .status(StatusCode::NO_CONTENT)
                 .body(Body::empty())
@@ -291,7 +254,7 @@ mod tests {
         assert_eq!(r.status(), StatusCode::NO_CONTENT);
 
         // Follows redirects
-        responses.lock().push_back(
+        mock.push(
             Response::builder()
                 .status(StatusCode::FOUND)
                 .header(LOCATION, "/foo")
@@ -305,7 +268,7 @@ mod tests {
 
         // Gives up after the retrying the specified number of times
         for _ in 0..=retry.max_retries {
-            responses.lock().push_back(
+            mock.push(
                 Response::builder()
                     .status(StatusCode::BAD_GATEWAY)
                     .body(Body::from("ignored"))
@@ -318,7 +281,6 @@ mod tests {
         assert_eq!(e.message, "502 Bad Gateway");
 
         // Shutdown
-        let _ = tx.send(());
-        server_handle.await.unwrap();
+        mock.shutdown().await
     }
 }